snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/experiment/experiment_tracking.py (new file)

@@ -0,0 +1,319 @@
+import functools
+import json
+import sys
+from typing import Any, Optional, Union
+from urllib.parse import quote
+
+import snowflake.snowpark._internal.utils as snowpark_utils
+from snowflake.ml import model, registry
+from snowflake.ml._internal.human_readable_id import hrid_generator
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.experiment import (
+    _entities as entities,
+    _experiment_info as experiment_info,
+)
+from snowflake.ml.experiment._client import experiment_tracking_sql_client as sql_client
+from snowflake.ml.model import type_hints
+from snowflake.ml.utils import sql_client as sql_client_utils
+from snowflake.snowpark import session
+
+DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
+
+
+class ExperimentTracking:
+    """
+    Class to manage experiments in Snowflake.
+    """
+
+    @snowpark_utils.private_preview(version="1.9.1")
+    def __init__(
+        self,
+        session: session.Session,
+        *,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
+    ) -> None:
+        """
+        Initializes experiment tracking within a pre-created schema.
+
+        Args:
+            session: The Snowpark Session to connect with Snowflake.
+            database_name: The name of the database. If None, the current database of the session
+                will be used. Defaults to None.
+            schema_name: The name of the schema. If None, the current schema of the session
+                will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
+
+        Raises:
+            ValueError: If no database is provided and no active database exists in the session.
+        """
+        if database_name:
+            self._database_name = sql_identifier.SqlIdentifier(database_name)
+        elif session_db := session.get_current_database():
+            self._database_name = sql_identifier.SqlIdentifier(session_db)
+        else:
+            raise ValueError("You need to provide a database to use experiment tracking.")
+
+        if schema_name:
+            self._schema_name = sql_identifier.SqlIdentifier(schema_name)
+        elif session_schema := session.get_current_schema():
+            self._schema_name = sql_identifier.SqlIdentifier(session_schema)
+        else:
+            self._schema_name = sql_identifier.SqlIdentifier("PUBLIC")
+
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+        self._registry = registry.Registry(
+            session=session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+
+        # The experiment in context
+        self._experiment: Optional[entities.Experiment] = None
+        # The run in context
+        self._run: Optional[entities.Run] = None
+
+    def set_experiment(
+        self,
+        experiment_name: str,
+    ) -> entities.Experiment:
+        """
+        Set the experiment in context. Creates a new experiment if it doesn't exist.
+
+        Args:
+            experiment_name: The name of the experiment.
+
+        Returns:
+            Experiment: The experiment that was set.
+        """
+        experiment_name = sql_identifier.SqlIdentifier(experiment_name)
+        if self._experiment and self._experiment.name == experiment_name:
+            return self._experiment
+        self._sql_client.create_experiment(
+            experiment_name=experiment_name,
+            creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
+        )
+        self._experiment = entities.Experiment(experiment_name=experiment_name)
+        self._run = None
+        return self._experiment
+
+    def delete_experiment(
+        self,
+        experiment_name: str,
+    ) -> None:
+        """
+        Delete an experiment.
+
+        Args:
+            experiment_name: The name of the experiment.
+        """
+        self._sql_client.drop_experiment(experiment_name=sql_identifier.SqlIdentifier(experiment_name))
+        if self._experiment and self._experiment.name == experiment_name:
+            self._experiment = None
+            self._run = None
+
+    @functools.wraps(registry.Registry.log_model)
+    def log_model(
+        self,
+        model: Union[type_hints.SupportedModelType, model.ModelVersion],
+        *,
+        model_name: str,
+        **kwargs: Any,
+    ) -> model.ModelVersion:
+        run = self._get_or_start_run()
+        with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
+            return self._registry.log_model(model, model_name=model_name, **kwargs)
+
+    def start_run(
+        self,
+        run_name: Optional[str] = None,
+    ) -> entities.Run:
+        """
+        Start a new run.
+
+        Args:
+            run_name: The name of the run. If None, a default name will be generated.
+
+        Returns:
+            Run: The run that was started.
+
+        Raises:
+            RuntimeError: If a run is already active.
+        """
+        if self._run:
+            raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
+        experiment = self._get_or_set_experiment()
+        run_name = (
+            sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
+        )
+        self._sql_client.add_run(
+            experiment_name=experiment.name,
+            run_name=run_name,
+        )
+        self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
+        return self._run
+
+    def end_run(self, run_name: Optional[str] = None) -> None:
+        """
+        End the current run if no run name is provided. Otherwise, the specified run is ended.
+
+        Args:
+            run_name: The name of the run to be ended. If None, the current run is ended.
+
+        Raises:
+            RuntimeError: If no run is active.
+        """
+        if not self._experiment:
+            raise RuntimeError("No experiment set. Please set an experiment before ending a run.")
+        experiment_name = self._experiment.name
+
+        if run_name:
+            run_name = sql_identifier.SqlIdentifier(run_name)
+        elif self._run:
+            run_name = self._run.name
+        else:
+            raise RuntimeError("No run is active. Please start a run before ending it.")
+
+        self._sql_client.commit_run(
+            experiment_name=experiment_name,
+            run_name=run_name,
+        )
+        if self._run and run_name == self._run.name:
+            self._run = None
+        self._print_urls(experiment_name=experiment_name, run_name=run_name)
+
+    def delete_run(
+        self,
+        run_name: str,
+    ) -> None:
+        """
+        Delete a run.
+
+        Args:
+            run_name: The name of the run to be deleted.
+
+        Raises:
+            RuntimeError: If no experiment is set.
+        """
+        if not self._experiment:
+            raise RuntimeError("No experiment set. Please set an experiment before deleting a run.")
+        self._sql_client.drop_run(
+            experiment_name=self._experiment.name,
+            run_name=sql_identifier.SqlIdentifier(run_name),
+        )
+        if self._run and self._run.name == run_name:
+            self._run = None
+
+    def log_metric(
+        self,
+        key: str,
+        value: float,
+        step: int = 0,
+    ) -> None:
+        """
+        Log a metric under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            key: The name of the metric.
+            value: The value of the metric.
+            step: The step of the metric. Defaults to 0.
+        """
+        self.log_metrics(metrics={key: value}, step=step)
+
+    def log_metrics(
+        self,
+        metrics: dict[str, float],
+        step: int = 0,
+    ) -> None:
+        """
+        Log metrics under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            metrics: Dictionary containing metric keys and float values.
+            step: The step of the metrics. Defaults to 0.
+        """
+        run = self._get_or_start_run()
+        metadata = run._get_metadata()
+        for key, value in metrics.items():
+            metadata.set_metric(key, value, step)
+        self._sql_client.modify_run(
+            experiment_name=run.experiment_name,
+            run_name=run.name,
+            run_metadata=json.dumps(metadata.to_dict()),
+        )
+
+    def log_param(
+        self,
+        key: str,
+        value: Any,
+    ) -> None:
+        """
+        Log a parameter under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            key: The name of the parameter.
+            value: The value of the parameter. Values can be of any type, but will be converted to string.
+        """
+        self.log_params({key: value})
+
+    def log_params(
+        self,
+        params: dict[str, Any],
+    ) -> None:
+        """
+        Log parameters under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
+                to string.
+        """
+        run = self._get_or_start_run()
+        metadata = run._get_metadata()
+        for key, value in params.items():
+            metadata.set_param(key, value)
+        self._sql_client.modify_run(
+            experiment_name=run.experiment_name,
+            run_name=run.name,
+            run_metadata=json.dumps(metadata.to_dict()),
+        )
+
+    def _get_or_set_experiment(self) -> entities.Experiment:
+        if self._experiment:
+            return self._experiment
+        return self.set_experiment(experiment_name=DEFAULT_EXPERIMENT_NAME)
+
+    def _get_or_start_run(self) -> entities.Run:
+        if self._run:
+            return self._run
+        return self.start_run()
+
+    def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
+        generator = hrid_generator.HRID16()
+        existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
+        existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
+        for _ in range(1000):
+            run_name = generator.generate()[1]
+            if run_name not in existing_run_names:
+                return sql_identifier.SqlIdentifier(run_name)
+        raise RuntimeError("Random run name generation failed.")
+
+    def _print_urls(
+        self,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        scheme: str = "https",
+        host: str = "app.snowflake.com",
+    ) -> None:
+
+        experiment_url = (
+            f"{scheme}://{host}/_deeplink/#/experiments"
+            f"/databases/{quote(str(self._database_name))}"
+            f"/schemas/{quote(str(self._schema_name))}"
+            f"/experiments/{quote(str(experiment_name))}"
+        )
+        run_url = experiment_url + f"/runs/{quote(str(run_name))}"
+        sys.stdout.write(f"🏃 View run {run_name} at: {run_url}\n")
+        sys.stdout.write(f"🧪 View experiment at: {experiment_url}\n")
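The hunk above adds the new ExperimentTracking API (marked private preview for 1.9.1). A minimal usage sketch, assuming the class is imported from the module path shown above; the connection parameters and `my_model` below are placeholders, not part of the package:

```python
# Hypothetical usage sketch of the ExperimentTracking API added in this release.
# `connection_parameters` and `my_model` are placeholders supplied by the caller.
from snowflake.snowpark import Session
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking

connection_parameters = {"account": "...", "user": "...", "password": "...", "database": "ML_DB"}
session = Session.builder.configs(connection_parameters).create()

exp = ExperimentTracking(session, database_name="ML_DB", schema_name="EXPERIMENTS")
exp.set_experiment("CHURN_MODEL")   # creates the experiment if it doesn't exist
exp.start_run("BASELINE")           # omit the name to get a generated human-readable ID

exp.log_params({"n_estimators": 100, "max_depth": 6})
exp.log_metric("auc", 0.91, step=0)

# log_model wraps registry.Registry.log_model and runs it under the
# experiment/run context via ExperimentInfoPatcher
mv = exp.log_model(my_model, model_name="CHURN_CLF", version_name="V1")

exp.end_run()  # commits the run and prints deep links to the run and experiment
```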
snowflake/ml/jobs/_utils/constants.py

@@ -6,10 +6,23 @@ DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
+TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
-
+# Base mount path
+STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
+
+# Stage subdirectory paths
+APP_STAGE_SUBPATH = "app"
+SYSTEM_STAGE_SUBPATH = "system"
+OUTPUT_STAGE_SUBPATH = "output"
+
+# Complete mount paths (automatically generated from base + subpath)
+APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"
+SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"
+OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"
+
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"

@@ -46,9 +59,7 @@ ENABLE_HEALTH_CHECKS = "false"
 JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
 JOB_POLL_MAX_DELAY_SECONDS = 30
 
-
-IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
-RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
 
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
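For reference, a short sketch of how the new stage mount constants compose and what the relocated default result path resolves to; the values are copied from the hunk above, with comments showing the concrete paths a job container would see:

```python
# Values mirror the constants added in snowflake/ml/jobs/_utils/constants.py above.
STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
APP_STAGE_SUBPATH, SYSTEM_STAGE_SUBPATH, OUTPUT_STAGE_SUBPATH = "app", "system", "output"

APP_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{APP_STAGE_SUBPATH}"        # /mnt/job_stage/app
SYSTEM_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{SYSTEM_STAGE_SUBPATH}"  # /mnt/job_stage/system
OUTPUT_MOUNT_PATH = f"{STAGE_VOLUME_MOUNT_PATH}/{OUTPUT_STAGE_SUBPATH}"  # /mnt/job_stage/output

# The default job result file moves from the relative "mljob_result.pkl"
# to an absolute path under the output mount:
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_MOUNT_PATH}/mljob_result.pkl"
assert RESULT_PATH_DEFAULT_VALUE == "/mnt/job_stage/output/mljob_result.pkl"
```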