snowflake-ml-python 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (26)
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/telemetry.py +56 -7
  3. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +1 -7
  4. snowflake/ml/experiment/_entities/run.py +15 -0
  5. snowflake/ml/experiment/experiment_tracking.py +61 -73
  6. snowflake/ml/feature_store/access_manager.py +1 -0
  7. snowflake/ml/feature_store/feature_store.py +86 -31
  8. snowflake/ml/feature_store/feature_view.py +12 -6
  9. snowflake/ml/fileset/stage_fs.py +12 -1
  10. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +6 -1
  12. snowflake/ml/jobs/_utils/spec_utils.py +12 -3
  13. snowflake/ml/jobs/job.py +8 -3
  14. snowflake/ml/jobs/manager.py +19 -6
  15. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  16. snowflake/ml/model/_client/model/model_version_impl.py +45 -17
  17. snowflake/ml/model/_client/ops/model_ops.py +11 -4
  18. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  19. snowflake/ml/model/models/huggingface_pipeline.py +6 -7
  20. snowflake/ml/monitoring/explain_visualize.py +3 -1
  21. snowflake/ml/version.py +1 -1
  22. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/METADATA +68 -5
  23. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/RECORD +26 -26
  24. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/WHEEL +0 -0
  25. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/licenses/LICENSE.txt +0 -0
  26. {snowflake_ml_python-1.18.0.dist-info → snowflake_ml_python-1.20.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/_internal/env_utils.py
+++ b/snowflake/ml/_internal/env_utils.py
@@ -16,6 +16,7 @@ from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
+from snowflake.snowpark._internal import utils as snowpark_utils
 
 
 class CONDA_OS(Enum):
@@ -38,6 +39,21 @@ SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
 SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
 
 
+def get_execution_context() -> str:
+    """Detect execution context: EXTERNAL, SPCS, or SPROC.
+
+    Returns:
+        str: The execution context - "SPROC" if running in a stored procedure,
+            "SPCS" if running in SPCS ML runtime, "EXTERNAL" otherwise.
+    """
+    if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        return "SPROC"
+    elif snowml_env.IN_ML_RUNTIME:
+        return "SPCS"
+    else:
+        return "EXTERNAL"
+
+
 def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
     """Validate the input pip requirement string according to PEP 508.
 
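The new `get_execution_context` helper takes no arguments and falls back to `"EXTERNAL"`, so it is cheap to probe. A minimal sketch of calling it follows; note that `snowflake.ml._internal` is an internal module rather than a stable public API, and the branching below is purely illustrative:

```python
# Sketch: branching on the detected execution context. Assumes
# snowflake-ml-python >= 1.20.0 is installed; run outside Snowflake,
# this simply prints the EXTERNAL branch.
from snowflake.ml._internal import env_utils

context = env_utils.get_execution_context()
if context == "SPROC":
    print("Running inside a stored procedure")
elif context == "SPCS":
    print("Running in the SPCS ML runtime")
else:
    print("Running externally:", context)
```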
--- a/snowflake/ml/_internal/telemetry.py
+++ b/snowflake/ml/_internal/telemetry.py
@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
 from snowflake import connector
 from snowflake.connector import connect, telemetry as connector_telemetry, time_util
 from snowflake.ml import version as snowml_version
-from snowflake.ml._internal import env
+from snowflake.ml._internal import env, env_utils
 from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
@@ -37,6 +37,22 @@ _CONNECTION_TYPES = {
 _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
+_conn: Optional[connector.SnowflakeConnection] = None
+
+
+def clear_cached_conn() -> None:
+    """Clear the cached Snowflake connection. Primarily for testing purposes."""
+    global _conn
+    if _conn is not None and _conn.is_valid():
+        _conn.close()
+    _conn = None
+
+
+def get_cached_conn() -> Optional[connector.SnowflakeConnection]:
+    """Get the cached Snowflake connection. Primarily for testing purposes."""
+    global _conn
+    return _conn
+
 
 def _get_login_token() -> Union[str, bytes]:
     with open("/snowflake/session/token") as f:
@@ -44,7 +60,11 @@ def _get_login_token() -> Union[str, bytes]:
 
 
 def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
-    conn = None
+    global _conn
+    if _conn is not None and _conn.is_valid():
+        return _conn
+
+    conn: Optional[connector.SnowflakeConnection] = None
     if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
         try:
            conn = connect(
@@ -66,6 +86,13 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
             # Failed to get an active session. No connection available.
             pass
 
+    # Cache the connection only if it is a SnowflakeConnection; at runtime it can
+    # instead be a StoredProcConnection, perhaps due to incorrect type hinting somewhere.
+    if isinstance(conn, connector.SnowflakeConnection):
+        # If _conn had expired, copy its telemetry data over to the new connection.
+        if _conn is not None and conn is not None:
+            conn._telemetry._log_batch.extend(_conn._telemetry._log_batch)
+        _conn = conn
     return conn
 
 
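With this change `_get_snowflake_connection` memoizes its result in the module-level `_conn` and replays any pending telemetry batch when a cached connection has expired. The two helpers added above are explicitly intended for tests; a sketch of how a test might use them (the test itself is hypothetical, not taken from the package's own suite):

```python
# Sketch of a test exercising the new connection cache. get_cached_conn() and
# clear_cached_conn() are the helpers added in this diff; the reset-between-tests
# pattern is illustrative.
from snowflake.ml._internal import telemetry


def test_connection_is_cached() -> None:
    telemetry.clear_cached_conn()              # start from a clean slate
    first = telemetry._get_snowflake_connection()
    second = telemetry._get_snowflake_connection()
    if first is not None:
        assert first is second                 # a valid connection is reused
        assert telemetry.get_cached_conn() is first
    telemetry.clear_cached_conn()              # avoid leaking state to other tests
```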
@@ -113,6 +140,13 @@ class TelemetryField(enum.Enum):
     FUNC_CAT_USAGE = "usage"
 
 
+@enum.unique
+class CustomTagKey(enum.Enum):
+    """Keys for custom tags in telemetry."""
+
+    EXECUTION_CONTEXT = "execution_context"
+
+
 class _TelemetrySourceType(enum.Enum):
     # Automatically inferred telemetry/statement parameters
     AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY"
@@ -441,6 +475,7 @@ def send_api_usage_telemetry(
     sfqids_extractor: Optional[Callable[..., list[str]]] = None,
     subproject_extractor: Optional[Callable[[Any], str]] = None,
     custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
+    log_execution_context: bool = True,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
     """
     Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
@@ -455,6 +490,8 @@ def send_api_usage_telemetry(
         sfqids_extractor: Extract sfqids from `self`.
         subproject_extractor: Extract subproject at runtime from `self`.
         custom_tags: Custom tags.
+        log_execution_context: If True, automatically detect and log execution context
+            (EXTERNAL, SPCS, or SPROC) in custom_tags.
 
     Returns:
         Decorator that sends function usage telemetry for any call to the decorated function.
@@ -495,6 +532,11 @@ def send_api_usage_telemetry(
         if subproject_extractor is not None:
             subproject_name = subproject_extractor(args[0])
 
+        # Add execution context if enabled
+        final_custom_tags = {**custom_tags} if custom_tags is not None else {}
+        if log_execution_context:
+            final_custom_tags[CustomTagKey.EXECUTION_CONTEXT.value] = env_utils.get_execution_context()
+
         statement_params = get_function_usage_statement_params(
             project=project,
             subproject=subproject_name,
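Decorated APIs therefore get an `execution_context` custom tag by default and can opt out per call site. A sketch of both modes (the class, project name, and method bodies are hypothetical; the decorator and its `log_execution_context` parameter come from this diff):

```python
# Sketch: opting in and out of execution-context tagging.
from snowflake.ml._internal import telemetry


class MyClient:
    @telemetry.send_api_usage_telemetry(project="MY_PROJECT")
    def tagged(self) -> None:
        ...  # telemetry includes the "execution_context" custom tag

    @telemetry.send_api_usage_telemetry(project="MY_PROJECT", log_execution_context=False)
    def untagged(self) -> None:
        ...  # telemetry is sent without the "execution_context" tag
```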
@@ -502,7 +544,7 @@ def send_api_usage_telemetry(
             function_name=_get_full_func_name(func),
             function_parameters=params,
             api_calls=api_calls,
-            custom_tags=custom_tags,
+            custom_tags=final_custom_tags,
         )
 
         def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
@@ -538,7 +580,10 @@ def send_api_usage_telemetry(
         if conn_attr_name:
             # raise AttributeError if conn attribute does not exist in `self`
             conn = operator.attrgetter(conn_attr_name)(args[0])
-            if not isinstance(conn, _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection)):
+            if not isinstance(
+                conn,
+                _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection),
+            ):
                 raise TypeError(
                     f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
                 )
@@ -560,7 +605,7 @@ def send_api_usage_telemetry(
             func_params=params,
             api_calls=api_calls,
             sfqids=sfqids,
-            custom_tags=custom_tags,
+            custom_tags=final_custom_tags,
         )
         try:
             return ctx.run(execute_func_with_statement_params)
@@ -571,7 +616,8 @@ def send_api_usage_telemetry(
                 raise
             if isinstance(e, snowpark_exceptions.SnowparkClientException):
                 me = snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e
+                    error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
+                    original_exception=e,
                 )
             else:
                 me = snowml_exceptions.SnowflakeMLException(
@@ -627,7 +673,10 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
 
 
 def _get_func_params(
-    func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
+    func: Callable[..., Any],
+    func_params_to_log: Optional[Iterable[str]],
+    args: Any,
+    kwargs: Any,
 ) -> dict[str, Any]:
     """
     Get function parameters.
--- a/snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
+++ b/snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
@@ -128,13 +128,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         run_name: sql_identifier.SqlIdentifier,
         artifact_path: str,
     ) -> list[artifact.ArtifactInfo]:
-        results = (
-            query_result_checker.SqlResultValidator(
-                self._session, f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}"
-            )
-            .has_dimensions(expected_cols=4)
-            .validate()
-        )
+        results = self._session.sql(f"LIST {self._build_snow_uri(experiment_name, run_name, artifact_path)}").collect()
         return [
             artifact.ArtifactInfo(
                 name=str(result.name).removeprefix(f"/versions/{run_name}/"),
--- a/snowflake/ml/experiment/_entities/run.py
+++ b/snowflake/ml/experiment/_entities/run.py
@@ -1,4 +1,5 @@
 import types
+import warnings
 from typing import TYPE_CHECKING, Optional
 
 from snowflake.ml._internal.utils import sql_identifier
@@ -7,6 +8,8 @@ from snowflake.ml.experiment import _experiment_info as experiment_info
 if TYPE_CHECKING:
     from snowflake.ml.experiment import experiment_tracking
 
+METADATA_SIZE_WARNING_MESSAGE = "It is likely that no further metrics or parameters will be logged for this run."
+
 
 class Run:
     def __init__(
@@ -20,6 +23,9 @@ class Run:
         self.experiment_name = experiment_name
         self.name = run_name
 
+        # Whether we've already shown the user a warning about exceeding the run metadata size limit.
+        self._warned_about_metadata_size = False
+
         self._patcher = experiment_info.ExperimentInfoPatcher(
             experiment_info=self._get_experiment_info(),
         )
@@ -45,3 +51,12 @@ class Run:
             ),
             run_name=self.name.identifier(),
         )
+
+    def _warn_about_run_metadata_size(self, sql_error_msg: str) -> None:
+        if not self._warned_about_metadata_size:
+            warnings.warn(
+                f"{sql_error_msg}. {METADATA_SIZE_WARNING_MESSAGE}",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            self._warned_about_metadata_size = True
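`_warn_about_run_metadata_size` is a warn-once guard: the first size-limit error on a `Run` surfaces as a `RuntimeWarning`, and repeats are swallowed. A standalone sketch of the same flag-guarded pattern (the class below is a stand-in for illustration, not the package's `Run`):

```python
# Sketch: flag-guarded "warn once" pattern mirroring Run's behavior above.
import warnings


class _WarnOnce:
    def __init__(self) -> None:
        self._warned = False

    def warn(self, msg: str) -> None:
        if not self._warned:
            warnings.warn(msg, RuntimeWarning, stacklevel=2)
            self._warned = True


w = _WarnOnce()
w.warn("size limit exceeded")  # emits a RuntimeWarning
w.warn("size limit exceeded")  # silently ignored
```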
--- a/snowflake/ml/experiment/experiment_tracking.py
+++ b/snowflake/ml/experiment/experiment_tracking.py
@@ -1,13 +1,13 @@
 import functools
 import json
 import sys
-from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
+from typing import Any, Optional, Union
 from urllib.parse import quote
 
 from snowflake import snowpark
 from snowflake.ml import model as ml_model, registry
 from snowflake.ml._internal.human_readable_id import hrid_generator
-from snowflake.ml._internal.utils import mixins, sql_identifier
+from snowflake.ml._internal.utils import connection_params, sql_identifier
 from snowflake.ml.experiment import (
     _entities as entities,
     _experiment_info as experiment_info,
@@ -21,34 +21,12 @@ from snowflake.ml.utils import sql_client as sql_client_utils
 
 DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
 
-P = ParamSpec("P")
-T = TypeVar("T")
 
-
-def _restore_session(
-    func: Callable[Concatenate["ExperimentTracking", P], T],
-) -> Callable[Concatenate["ExperimentTracking", P], T]:
-    @functools.wraps(func)
-    def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
-        if self._session is None:
-            if self._session_state is None:
-                raise RuntimeError(
-                    f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
-                )
-            self._set_session(self._session_state)
-            if self._session is None:
-                raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
-        return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-class ExperimentTracking(mixins.SerializableSessionMixin):
+class ExperimentTracking:
     """
     Class to manage experiments in Snowflake.
     """
 
-    @snowpark._internal.utils.private_preview(version="1.9.1")
     def __init__(
         self,
         session: snowpark.Session,
@@ -93,10 +71,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             database_name=self._database_name,
             schema_name=self._schema_name,
         )
-        self._session: Optional[snowpark.Session] = session
-        # Used to store information about the session if the session could not be restored during unpickling
-        # _session_state is None if and only if _session is not None
-        self._session_state: Optional[mixins._SessionState] = None
+        self._session = session
 
         # The experiment in context
         self._experiment: Optional[entities.Experiment] = None
@@ -104,35 +79,40 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run: Optional[entities.Run] = None
 
     def __getstate__(self) -> dict[str, Any]:
-        state = super().__getstate__()
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
+
         # Remove unpicklable attributes
+        state["_session"] = None
         state["_sql_client"] = None
         state["_registry"] = None
         return state
 
-    def _set_session(self, session_state: mixins._SessionState) -> None:
-        try:
-            super()._set_session(session_state)
-            assert self._session is not None
-        except (snowpark.exceptions.SnowparkSessionException, AssertionError):
-            # If session was not set, store the session state
-            self._session = None
-            self._session_state = session_state
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)  # type: ignore[misc]
         else:
-            # If session was set, clear the session state, and reinitialize the SQL client and registry
-            self._session_state = None
-            self._sql_client = sql_client.ExperimentTrackingSQLClient(
-                session=self._session,
-                database_name=self._database_name,
-                schema_name=self._schema_name,
-            )
-            self._registry = registry.Registry(
-                session=self._session,
-                database_name=self._database_name,
-                schema_name=self._schema_name,
-            )
+            self.__dict__.update(state)
+
+        # Restore unpicklable attributes
+        options: dict[str, Any] = connection_params.SnowflakeLoginOptions()
+        options["client_session_keep_alive"] = True  # Needed for long-running training jobs
+        self._session = snowpark.Session.builder.configs(options).getOrCreate()
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session=self._session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+        self._registry = registry.Registry(
+            session=self._session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
 
-    @_restore_session
     def set_experiment(
         self,
         experiment_name: str,
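Pickling is now symmetric: `__getstate__` drops the session, SQL client, and registry, while `__setstate__` eagerly rebuilds them from `connection_params.SnowflakeLoginOptions()` instead of restoring lazily via the removed `_restore_session` decorator. A sketch of the round trip (it assumes the unpickling process has valid Snowflake credentials resolvable by `SnowflakeLoginOptions`; the scenario is illustrative):

```python
# Sketch: pickling an ExperimentTracking object, e.g. to ship it to a remote
# training job. On pickle.dumps the session/SQL client/registry are dropped;
# on pickle.loads, __setstate__ recreates them from local login options.
import pickle

from snowflake.ml.experiment import experiment_tracking


def round_trip(exp: experiment_tracking.ExperimentTracking) -> experiment_tracking.ExperimentTracking:
    payload = pickle.dumps(exp)   # unpicklable attributes are set to None
    return pickle.loads(payload)  # session and clients are rebuilt here
```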
@@ -157,7 +137,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
         return self._experiment
 
-    @_restore_session
     def delete_experiment(
         self,
         experiment_name: str,
@@ -174,10 +153,8 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
 
     @functools.wraps(registry.Registry.log_model)
-    @_restore_session
     def log_model(
         self,
-        /,  # self needs to be a positional argument to stop mypy from complaining
         model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
         *,
         model_name: str,
@@ -187,7 +164,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
             return self._registry.log_model(model, model_name=model_name, **kwargs)
 
-    @_restore_session
     def start_run(
         self,
         run_name: Optional[str] = None,
@@ -229,7 +205,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
         return self._run
 
-    @_restore_session
     def end_run(self, run_name: Optional[str] = None) -> None:
         """
         End the current run if no run name is provided. Otherwise, the specified run is ended.
@@ -259,7 +234,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         self._run = None
         self._print_urls(experiment_name=experiment_name, run_name=run_name)
 
-    @_restore_session
     def delete_run(
         self,
         run_name: str,
@@ -298,7 +272,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         """
         self.log_metrics(metrics={key: value}, step=step)
 
-    @_restore_session
     def log_metrics(
         self,
         metrics: dict[str, float],
@@ -310,16 +283,26 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         Args:
             metrics: Dictionary containing metric keys and float values.
             step: The step of the metrics. Defaults to 0.
+
+        Raises:
+            snowpark.exceptions.SnowparkSQLException: If logging metrics fails due to Snowflake SQL errors,
+                except for run metadata size limit errors which will issue a warning instead of raising.
         """
         run = self._get_or_start_run()
         metrics_list = []
         for key, value in metrics.items():
             metrics_list.append(entities.Metric(key, value, step))
-        self._sql_client.modify_run_add_metrics(
-            experiment_name=run.experiment_name,
-            run_name=run.name,
-            metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
-        )
+        try:
+            self._sql_client.modify_run_add_metrics(
+                experiment_name=run.experiment_name,
+                run_name=run.name,
+                metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
+            )
+        except snowpark.exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == 400003:  # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
+                run._warn_about_run_metadata_size(e.message)
+            else:
+                raise
 
     def log_param(
         self,
@@ -335,7 +318,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         """
         self.log_params({key: value})
 
-    @_restore_session
     def log_params(
         self,
         params: dict[str, Any],
@@ -346,18 +328,27 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         Args:
             params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
                 to string.
+
+        Raises:
+            snowpark.exceptions.SnowparkSQLException: If logging parameters fails due to Snowflake SQL errors,
+                except for run metadata size limit errors which will issue a warning instead of raising.
         """
         run = self._get_or_start_run()
         params_list = []
         for key, value in params.items():
             params_list.append(entities.Param(key, str(value)))
-        self._sql_client.modify_run_add_params(
-            experiment_name=run.experiment_name,
-            run_name=run.name,
-            params=json.dumps([param.to_dict() for param in params_list]),
-        )
+        try:
+            self._sql_client.modify_run_add_params(
+                experiment_name=run.experiment_name,
+                run_name=run.name,
+                params=json.dumps([param.to_dict() for param in params_list]),
+            )
+        except snowpark.exceptions.SnowparkSQLException as e:
+            if e.sql_error_code == 400003:  # EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED
+                run._warn_about_run_metadata_size(e.message)
+            else:
+                raise
 
-    @_restore_session
     def log_artifact(
         self,
         local_path: str,
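Together with the `Run` change above, both `log_metrics` and `log_params` now downgrade SQL error 400003 (`EXPERIMENT_RUN_PROPERTY_SIZE_LIMIT_EXCEEDED`) to a one-time warning instead of raising. A sketch of what calling code observes (the oversized payload is contrived, and `exp` is assumed to be a live `ExperimentTracking` instance with an active run):

```python
# Sketch: observing the size-limit warning instead of an exception.
import warnings

from snowflake.ml.experiment import experiment_tracking


def probe_size_limit(exp: experiment_tracking.ExperimentTracking) -> None:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        exp.log_params({"blob": "x" * 10_000_000})  # may exceed the run metadata limit
    if caught:
        print("metadata size limit reached:", caught[0].message)
```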
@@ -381,7 +372,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             file_path=file_path,
         )
 
-    @_restore_session
     def list_artifacts(
         self,
         run_name: str,
@@ -410,7 +400,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             artifact_path=artifact_path or "",
         )
 
-    @_restore_session
     def download_artifacts(
         self,
         run_name: str,
@@ -452,7 +441,6 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             return self._run
         return self.start_run()
 
-    @_restore_session
     def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
         generator = hrid_generator.HRID16()
         existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
--- a/snowflake/ml/feature_store/access_manager.py
+++ b/snowflake/ml/feature_store/access_manager.py
@@ -202,6 +202,7 @@ def _configure_role_hierarchy(
     session.sql(f"GRANT ROLE {producer_role} TO ROLE {session.get_current_role()}").collect()
 
     if consumer_role is not None:
+        # Create CONSUMER and grant it to PRODUCER to build hierarchy
         consumer_role = SqlIdentifier(consumer_role)
         session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
         session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()