snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +16 -2
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +8 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +15 -0
- snowflake/ml/jobs/job.py +186 -40
- snowflake/ml/jobs/manager.py +48 -39
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +168 -18
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/human_readable_id/adjectives.txt
@@ -1,3 +1,4 @@
+aerial
 afraid
 ancient
 angry
@@ -26,7 +27,6 @@ dull
 empty
 evil
 fast
-fat
 fluffy
 foolish
 fresh
@@ -57,10 +57,10 @@ lovely
 lucky
 massive
 mean
+metallic
 mighty
 modern
 moody
-nasty
 neat
 nervous
 new
@@ -85,7 +85,6 @@ rotten
 rude
 selfish
 serious
-shaggy
 sharp
 short
 shy
@@ -96,14 +95,15 @@ slippery
 smart
 smooth
 soft
+solid
 sour
 spicy
 splendid
 spotty
+squishy
 stale
 strange
 strong
-stupid
 sweet
 swift
 tall
@@ -116,7 +116,6 @@ tidy
 tiny
 tough
 tricky
-ugly
 warm
 weak
 wet
@@ -124,5 +123,6 @@ wicked
 wise
 witty
 wonderful
+wooden
 yellow
 young

snowflake/ml/_internal/human_readable_id/animals.txt
@@ -1,10 +1,9 @@
 anaconda
 ant
-ape
-baboon
 badger
 bat
 bear
+beetle
 bird
 bobcat
 bulldog
@@ -73,7 +72,6 @@ lobster
 mayfly
 mamba
 mole
-monkey
 moose
 moth
 mouse
@@ -114,6 +112,7 @@ swan
 termite
 tiger
 treefrog
+tuna
 turkey
 turtle
 vampirebat
@@ -126,3 +125,4 @@ worm
 yak
 yeti
 zebra
+zebrafish

snowflake/ml/_internal/telemetry.py
@@ -73,6 +73,7 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
 class TelemetryProject(enum.Enum):
     MLOPS = "MLOps"
     MODELING = "ModelDevelopment"
+    EXPERIMENT_TRACKING = "ExperimentTracking"
     # TODO: Update with remaining projects.
 
 
@@ -464,14 +465,14 @@ def send_api_usage_telemetry(
 
     # noqa: DAR402
     """
-    start_time = time.perf_counter()
-
     if subproject is not None and subproject_extractor is not None:
         raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
 
     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
         @functools.wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
+            start_time = time.perf_counter()
+
             params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
 
             api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
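
The second hunk above moves start_time = time.perf_counter() out of the body of send_api_usage_telemetry and into the inner wrap function, so the elapsed time reported by telemetry is measured per call rather than from the moment the decorator was applied. A minimal sketch of that pattern (illustrative only, not the actual telemetry implementation):

import functools
import time
from typing import Any, Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])


def timed(func: _F) -> _F:
    @functools.wraps(func)
    def wrap(*args: Any, **kwargs: Any) -> Any:
        start_time = time.perf_counter()  # per-call baseline, mirroring the new placement
        try:
            return func(*args, **kwargs)
        finally:
            print(f"{func.__name__} took {time.perf_counter() - start_time:.4f}s")

    return wrap  # type: ignore[return-value]


@timed
def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)
    return a + b


slow_add(1, 2)  # prints an elapsed time measured from this call, not from decoration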

snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
@@ -1,17 +1,17 @@
 from typing import Optional
 
+from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import query_result_checker, sql_identifier
 from snowflake.ml.experiment._client import artifact
 from snowflake.ml.model._client.sql import _base
 from snowflake.ml.utils import sql_client
 from snowflake.snowpark import file_operation, row, session
 
+RUN_NAME_COL_NAME = "name"
+RUN_METADATA_COL_NAME = "metadata"
 
-class ExperimentTrackingSQLClient(_base._BaseSQLClient):
-
-    RUN_NAME_COL_NAME = "name"
-    RUN_METADATA_COL_NAME = "metadata"
 
+class ExperimentTrackingSQLClient(_base._BaseSQLClient):
     def __init__(
         self,
         session: session.Session,
@@ -28,6 +28,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         """
         super().__init__(session, database_name=database_name, schema_name=schema_name)
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def create_experiment(
         self,
         experiment_name: sql_identifier.SqlIdentifier,
@@ -39,24 +40,21 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
         query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
             expected_rows=1, expected_cols=1
         ).validate()
 
-    def add_run(
-        self,
-        *,
-        experiment_name: sql_identifier.SqlIdentifier,
-        run_name: sql_identifier.SqlIdentifier,
-        live: bool = True,
-    ) -> None:
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
+    def add_run(self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier) -> None:
         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
         query_result_checker.SqlResultValidator(
-            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def commit_run(
         self,
         *,
@@ -68,6 +66,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def drop_run(
         self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
     ) -> None:
@@ -76,6 +75,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def modify_run_add_metrics(
         self,
         *,
@@ -89,6 +89,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def modify_run_add_params(
         self,
         *,
@@ -102,6 +103,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def put_artifact(
         self,
         *,
@@ -118,6 +120,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             auto_compress=auto_compress,
         )[0]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def list_artifacts(
         self,
         *,
@@ -142,6 +145,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             for result in results
         ]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def get_artifact(
         self,
         *,
@@ -155,6 +159,7 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             target_directory=target_path,
         )[0]
 
+    @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
     def show_runs_in_experiment(
         self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
     ) -> list[row.Row]:
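
RUN_NAME_COL_NAME and RUN_METADATA_COL_NAME move from class attributes to module-level constants, and every public method of ExperimentTrackingSQLClient now carries the send_api_usage_telemetry decorator with the new EXPERIMENT_TRACKING project. A hedged sketch of how a caller of show_runs_in_experiment() might consume those constants; the row shape (a JSON metadata column carrying a status field) is inferred from the experiment_tracking.py hunk further down, not from any documented API:

import json

# These mirror the new module-level constants in experiment_tracking_sql_client.py.
RUN_NAME_COL_NAME = "name"
RUN_METADATA_COL_NAME = "metadata"


def is_run_running(run_row: dict) -> bool:
    # A run can only be resumed while its recorded status is still RUNNING.
    metadata = json.loads(run_row[RUN_METADATA_COL_NAME])
    return metadata.get("status") == "RUNNING"


example_row = {RUN_NAME_COL_NAME: "baseline", RUN_METADATA_COL_NAME: '{"status": "RUNNING"}'}
print(is_run_running(example_row))  # True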

snowflake/ml/experiment/callback/keras.py
@@ -20,6 +20,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
     def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
@@ -59,5 +61,6 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=self.model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )

snowflake/ml/experiment/callback/lightgbm.py
@@ -17,6 +17,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -27,6 +28,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
         super().__init__(eval_result={})
@@ -50,6 +52,7 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=env.model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )
         else:

snowflake/ml/experiment/callback/xgboost.py
@@ -20,6 +20,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
         log_params: bool = True,
         log_every_n_epochs: int = 1,
         model_name: Optional[str] = None,
+        version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
         self._experiment_tracking = experiment_tracking
@@ -30,6 +31,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
             raise ValueError("`log_every_n_epochs` must be positive.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
+        self.version_name = version_name
         self.model_signature = model_signature
 
     def before_training(self, model: xgb.Booster) -> xgb.Booster:
@@ -61,6 +63,7 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
             self._experiment_tracking.log_model(  # type: ignore[call-arg]
                 model=model,
                 model_name=model_name,
+                version_name=self.version_name,
                 signatures={"predict": self.model_signature},
             )
 
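
All three callbacks (Keras, LightGBM, XGBoost) gain the same version_name argument, which is forwarded to log_model() alongside model_name so the logged model lands in an explicitly named version. A hedged usage sketch for the XGBoost callback; the keyword names are taken from the hunks above, while the way the ExperimentTracking instance is obtained is assumed and elided:

from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback


def make_callback(experiment_tracking) -> SnowflakeXgboostCallback:
    return SnowflakeXgboostCallback(
        experiment_tracking=experiment_tracking,
        log_params=True,
        log_every_n_epochs=1,
        model_name="MY_MODEL",     # the model is logged under this name...
        version_name="V1",         # ...and now under an explicitly chosen version as well
    )

# SnowflakeXgboostCallback subclasses xgb.callback.TrainingCallback, so it is passed the
# usual way, e.g. xgb.train(params, dtrain, callbacks=[make_callback(exp)]).
# SnowflakeKerasCallback and SnowflakeLightgbmCallback accept version_name analogously.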

snowflake/ml/experiment/experiment_tracking.py
@@ -193,23 +193,35 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
         run_name: Optional[str] = None,
     ) -> entities.Run:
         """
-        Start a new run.
+        Start a new run. If a run name of an existing run is provided, resumes the run if it is running.
 
         Args:
             run_name: The name of the run. If None, a default name will be generated.
 
         Returns:
-            Run: The run that was started.
+            Run: The run that was started or resumed.
 
         Raises:
-            RuntimeError: If a run is already active.
+            RuntimeError: If a run is already active. If a run with the same name exists but is not running.
         """
         if self._run:
             raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
         experiment = self._get_or_set_experiment()
-
-
-
+
+        if run_name is None:
+            run_name = self._generate_run_name(experiment)
+        elif runs := self._sql_client.show_runs_in_experiment(experiment_name=experiment.name, like=run_name):
+            if "RUNNING" != json.loads(runs[0][sql_client.RUN_METADATA_COL_NAME])["status"]:
+                raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
+            else:
+                self._run = entities.Run(
+                    experiment_tracking=self,
+                    experiment_name=experiment.name,
+                    run_name=sql_identifier.SqlIdentifier(run_name),
+                )
+                return self._run
+
+        run_name = sql_identifier.SqlIdentifier(run_name)
         self._sql_client.add_run(
             experiment_name=experiment.name,
             run_name=run_name,
@@ -444,7 +456,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
     def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
         generator = hrid_generator.HRID16()
         existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
-        existing_run_names = [row[sql_client.
+        existing_run_names = [row[sql_client.RUN_NAME_COL_NAME] for row in existing_runs]
         for _ in range(1000):
             run_name = generator.generate()[1]
             if run_name not in existing_run_names:
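
The run-start hunk adds resume semantics: a None run_name still gets a generated name, while an explicit name that matches an existing run either re-attaches to it (if its status is still RUNNING) or raises. A hedged restatement of that branch as a standalone function so the control flow is easier to follow; the dataclass and the exact row shape here are illustrative, with only the status check and error message taken from the hunk:

import json
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class RunDecision:
    run_name: str
    resume: bool  # True: attach to the existing RUNNING run instead of adding a new one


def decide_run(
    run_name: Optional[str],
    existing_rows: list,
    generate_name: Callable[[], str],
) -> RunDecision:
    if run_name is None:
        return RunDecision(run_name=generate_name(), resume=False)
    matches = [r for r in existing_rows if r["name"] == run_name]
    if matches:
        status = json.loads(matches[0]["metadata"])["status"]
        if status != "RUNNING":
            raise RuntimeError(f"Run {run_name} exists but cannot be resumed as it is no longer running.")
        return RunDecision(run_name=run_name, resume=True)
    return RunDecision(run_name=run_name, resume=False)


rows = [{"name": "baseline", "metadata": '{"status": "RUNNING"}'}]
print(decide_run("baseline", rows, lambda: "auto-generated"))  # resume=True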