snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/platform_capabilities.py
CHANGED

@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
 INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
 SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
+ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
+FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"


 class PlatformCapabilities:

@@ -80,6 +82,12 @@ class PlatformCapabilities:
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)

+    def is_model_method_signature_parameters_enabled(self) -> bool:
+        return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)
+
+    def is_inference_autocapture_enabled(self) -> bool:
+        return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
         try:

@@ -182,3 +190,31 @@ class PlatformCapabilities:
             f"current={current_version}, feature={feature_version}, enabled={result}"
         )
         return result
+
+    def _is_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the feature parameter value belongs to enabled values.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            bool: True if the value is "ENABLED" or "ENABLED_PUBLIC_PREVIEW",
+                False if the value is "DISABLED", "DISABLED_PRIVATE_PREVIEW", or not set.
+
+        Raises:
+            ValueError: If the feature value is set but not one of the recognized values.
+        """
+        value = self.features.get(feature_name)
+        if value is None:
+            logger.debug(f"Feature {feature_name} not found.")
+            return False
+
+        if isinstance(value, str):
+            value_str = str(value)
+            if value_str.upper() in ["ENABLED", "ENABLED_PUBLIC_PREVIEW"]:
+                return True
+            elif value_str.upper() in ["DISABLED", "DISABLED_PRIVATE_PREVIEW"]:
+                return False
+            else:
+                raise ValueError(f"Invalid feature parameter value: {value} for feature {feature_name}")
+        raise ValueError(f"Invalid feature parameter string value: {value} for feature {feature_name}")
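Unlike the existing `_get_bool_feature`, which takes a default, the new `_is_feature_enabled` maps string-valued platform parameters onto booleans. A minimal standalone sketch of that mapping (the feature dict here is hypothetical, not part of the package):

def is_enabled(features: dict, name: str) -> bool:
    # Mirrors PlatformCapabilities._is_feature_enabled: unset -> False;
    # ENABLED / ENABLED_PUBLIC_PREVIEW -> True;
    # DISABLED / DISABLED_PRIVATE_PREVIEW -> False; anything else -> ValueError.
    value = features.get(name)
    if value is None:
        return False
    if isinstance(value, str):
        if value.upper() in ("ENABLED", "ENABLED_PUBLIC_PREVIEW"):
            return True
        if value.upper() in ("DISABLED", "DISABLED_PRIVATE_PREVIEW"):
            return False
        raise ValueError(f"Invalid feature parameter value: {value} for feature {name}")
    raise ValueError(f"Invalid feature parameter string value: {value} for feature {name}")

assert is_enabled({"FEATURE_MODEL_INFERENCE_AUTOCAPTURE": "Enabled_Public_Preview"}, "FEATURE_MODEL_INFERENCE_AUTOCAPTURE")
assert not is_enabled({}, "FEATURE_MODEL_INFERENCE_AUTOCAPTURE")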
snowflake/ml/_internal/utils/url.py
ADDED

@@ -0,0 +1,42 @@
+from urllib.parse import urlunparse
+
+from snowflake import snowpark as snowpark
+
+JOB_URL_PREFIX = "#/compute/job/"
+SERVICE_URL_PREFIX = "#/compute/service/"
+
+
+def get_snowflake_url(
+    session: snowpark.Session,
+    url_path: str,
+    params: str = "",
+    query: str = "",
+    fragment: str = "",
+) -> str:
+    """Construct a Snowflake URL from session connection details and URL components.
+
+    Args:
+        session: The Snowpark session containing connection details.
+        url_path: The path component of the URL (e.g., "/compute/job/123").
+        params: Optional parameters for the URL (RFC 1808). Defaults to "".
+        query: Optional query string for the URL. Defaults to "".
+        fragment: Optional fragment identifier for the URL (e.g., "#section"). Defaults to "".
+
+    Returns:
+        A fully constructed Snowflake URL string with scheme, host, and specified components.
+    """
+    scheme = "https"
+    if hasattr(session.connection, "scheme"):
+        scheme = session.connection.scheme
+    host = session.connection.host
+
+    return urlunparse(
+        (
+            scheme,
+            host,
+            url_path,
+            params,
+            query,
+            fragment,
+        )
+    )
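`urlunparse` joins a 6-tuple of (scheme, netloc, path, params, query, fragment), so the fragment-based prefixes above yield Snowsight-style deep links. A quick standalone check; the host below is a placeholder, whereas `get_snowflake_url` reads the real one from `session.connection.host`:

from urllib.parse import urlunparse

host = "myorg-myaccount.snowflakecomputing.com"  # illustrative placeholder
url = urlunparse(("https", host, "", "", "", "/compute/job/MY_JOB_ID"))
print(url)  # https://myorg-myaccount.snowflakecomputing.com#/compute/job/MY_JOB_ID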
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED

@@ -1,6 +1,8 @@
+import base64
 import collections
 import logging
 import os
+import re
 import time
 from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union

@@ -165,8 +167,71 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
         # Re-shuffle input files on each iteration start
         if shuffle:
             np.random.shuffle(sources)
-
-
+        try:
+            pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+            return pa_dataset
+        except Exception as e:
+            self._tmp_debug_parquet_invalid(e, sources)
+
+    def _tmp_debug_parquet_invalid(self, e: Exception, sources: list[Any]) -> None:
+        # Attach rich debug info to help diagnose intermittent Parquet footer/magic byte errors
+        debug_parts: list[str] = []
+        debug_parts.append("SNOWML DEBUG: Failed to construct Arrow Dataset")
+        debug_parts.append(
+            "SNOWML DEBUG: " f"data_sources_count={len(self._data_sources)} " f"resolved_sources_count={len(sources)}"
+        )
+        # Try to include the exact file path mentioned by pyarrow, if present
+        error_text = str(e)
+        snow_paths: list[str] = []
+        try:
+            # Extract snow://... tokens possibly wrapped in quotes
+            for match in re.finditer(r'(snow://[^\s\'"]+)', error_text):
+                token = match.group(1).rstrip(").,;]")
+                snow_paths.append(token)
+        except Exception:
+            pass
+        fs = self._kwargs.get("filesystem")
+        if fs is not None:
+            # Always include a directory listing with sizes for context
+            try:
+                debug_parts.append("SNOWML DEBUG: Listing resolved sources with sizes:")
+                for s in sources:
+                    try:
+                        info = fs.info(s)
+                        size = info.get("size", None)
+                        md5 = info.get("md5", None)
+                        debug_parts.append(f"  - {s} size={size} md5={md5}")
+                    except Exception as le:
+                        debug_parts.append(f"  - {s} info_failed={le}")
+            except Exception as le:
+                debug_parts.append(f"SNOWML DEBUG: listing sources failed: {le}")
+            # If pyarrow referenced a specific file, dump its full contents (base64) for inspection
+            for path in snow_paths[:1]:  # usually only one path appears in the message
+                try:
+                    info = fs.info(path)
+                    size = info.get("size", None)
+                    debug_parts.append(f"SNOWML DEBUG: Inspecting referenced file: {path} size={size}")
+                    with fs.open(path, "rb") as f:
+                        content = f.read()
+                    magic_head = content[:4]
+                    magic_tail = content[-4:] if content else b""
+                    looks_like_parquet = (magic_head == b"PAR1") and (magic_tail == b"PAR1")
+                    debug_parts.append(
+                        "SNOWML DEBUG: "
+                        f"file_magic_head={magic_head!r} "
+                        f"file_magic_tail={magic_tail!r} "
+                        f"parquet_magic_detected={looks_like_parquet}"
+                    )
+                    b64 = base64.b64encode(content).decode("ascii")
+                    debug_parts.append("SNOWML DEBUG: file_content_base64 (entire file):")
+                    debug_parts.append(b64)
+                except Exception as fe:
+                    debug_parts.append(f"SNOWML DEBUG: failed to read referenced file {path}: {fe}")
+        else:
+            debug_parts.append("SNOWML DEBUG: No filesystem available; cannot inspect files")
+        debug_message = "\n".join(debug_parts)
+        # Re-raise with augmented message to surface in stacktrace
+        raise RuntimeError(f"{e}\n{debug_message}") from e

     def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
         """Generate new batches from the existing record batch buffer."""
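The magic-byte test in the debug helper relies on the Parquet file layout: a structurally intact file both begins and ends with the 4-byte marker b"PAR1", with the footer just before the trailing marker. A standalone sketch of the same check; the path is hypothetical:

def looks_like_parquet(path: str) -> bool:
    # A valid Parquet file starts and ends with b"PAR1".
    with open(path, "rb") as f:
        data = f.read()
    return len(data) >= 8 and data[:4] == b"PAR1" and data[-4:] == b"PAR1"

print(looks_like_parquet("example.parquet"))  # hypothetical local file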
snowflake/ml/data/data_connector.py
CHANGED

@@ -1,5 +1,15 @@
 import os
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    Literal,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    overload,
+)

 import numpy.typing as npt
 from typing_extensions import deprecated

@@ -11,6 +21,7 @@ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
 from snowflake.snowpark import context as sp_context

 if TYPE_CHECKING:
+    import datasets as hf_datasets
     import pandas as pd
     import ray
     import tensorflow as tf

@@ -257,6 +268,97 @@ class DataConnector:
         except ImportError as e:
             raise ImportError("Ray is not installed, please install ray in your local environment.") from e

+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[False] = ...,
+        limit: Optional[int] = ...,
+    ) -> "hf_datasets.Dataset":
+        ...
+
+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[True],
+        limit: Optional[int] = ...,
+        batch_size: int = ...,
+        shuffle: bool = ...,
+        drop_last_batch: bool = ...,
+    ) -> "hf_datasets.IterableDataset":
+        ...
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["streaming", "limit", "batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: bool = False,
+        limit: Optional[int] = None,
+        batch_size: int = 1024,
+        shuffle: bool = False,
+        drop_last_batch: bool = False,
+    ) -> "Union[hf_datasets.Dataset, hf_datasets.IterableDataset]":
+        """Retrieve the Snowflake data as a HuggingFace Dataset.
+
+        Args:
+            streaming: If True, returns an IterableDataset that streams data in batches.
+                If False (default), returns an in-memory Dataset.
+            limit: Maximum number of rows to load. If None, loads all rows.
+            batch_size: Size of batches for internal data retrieval.
+            shuffle: Whether to shuffle the data. If True, files will be shuffled and rows
+                in each file will also be shuffled.
+            drop_last_batch: Whether to drop the last batch if it's smaller than batch_size.
+
+        Returns:
+            A HuggingFace Dataset (in-memory) or IterableDataset (streaming).
+        """
+        import datasets as hf_datasets
+
+        if streaming:
+            return self._to_huggingface_iterable_dataset(
+                limit=limit,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last_batch=drop_last_batch,
+            )
+        else:
+            return hf_datasets.Dataset.from_pandas(self._ingestor.to_pandas(limit))
+
+    def _to_huggingface_iterable_dataset(
+        self,
+        *,
+        limit: Optional[int],
+        batch_size: int,
+        shuffle: bool,
+        drop_last_batch: bool,
+    ) -> "hf_datasets.IterableDataset":
+        """Create a HuggingFace IterableDataset that streams data in batches."""
+        import datasets as hf_datasets
+
+        def generator() -> Generator[dict[str, Any], None, None]:
+            rows_yielded = 0
+            for batch in self._ingestor.to_batches(batch_size, shuffle, drop_last_batch):
+                # Yield individual rows from each batch
+                num_rows = len(next(iter(batch.values())))
+                for i in range(num_rows):
+                    if limit is not None and rows_yielded >= limit:
+                        return
+                    yield {k: v[i].item() if hasattr(v[i], "item") else v[i] for k, v in batch.items()}
+                    rows_yielded += 1
+
+        result = hf_datasets.IterableDataset.from_generator(generator)
+        # In datasets >= 3.x, from_generator returns IterableDatasetDict
+        # We need to extract the IterableDataset from the dict
+        if hasattr(hf_datasets, "IterableDatasetDict") and isinstance(result, hf_datasets.IterableDatasetDict):
+            # Get the first (and only) dataset from the dict
+            result = next(iter(result.values()))
+        return result
+

 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
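The overloads make the return type track the `streaming` flag: False materializes an in-memory `datasets.Dataset` via `from_pandas`, while True wraps the ingestor's batch iterator in a generator-backed `IterableDataset`. A hedged usage sketch, assuming `dc` is a previously constructed DataConnector (e.g. from `DataConnector.from_dataframe`):

ds = dc.to_huggingface_dataset(limit=1000)  # datasets.Dataset, fully in memory
stream = dc.to_huggingface_dataset(streaming=True, batch_size=512, shuffle=True)
for row in stream.take(5):  # IterableDataset yields dict rows lazily
    print(row)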
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
CHANGED

@@ -41,8 +41,14 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

     @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
-    def drop_experiment(
-
+    def drop_experiment(
+        self,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        experiment_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(database_name, schema_name, experiment_name)
         query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
             expected_rows=1, expected_cols=1
         ).validate()
snowflake/ml/experiment/callback/__init__.py
File without changes
snowflake/ml/experiment/callback/keras.py
CHANGED

@@ -12,6 +12,8 @@ if TYPE_CHECKING:


 class SnowflakeKerasCallback(keras.callbacks.Callback):
+    """Keras callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",

@@ -23,12 +25,33 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new Keras callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every epoch.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
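A hedged usage sketch of the documented constructor; `exp`, `sig`, and the compiled `model` with its training data are assumed to exist, and the LightGBM and XGBoost callbacks below follow the same pattern:

from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback

cb = SnowflakeKerasCallback(
    exp,                   # snowflake.ml.experiment.ExperimentTracking instance
    log_every_n_epochs=5,  # must be a positive integer, per the new validation
    model_signature=sig,   # required because log_model defaults to True
)
model.fit(x_train, y_train, epochs=20, callbacks=[cb])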
snowflake/ml/experiment/callback/lightgbm.py
CHANGED

@@ -3,12 +3,16 @@ from warnings import warn

 import lightgbm as lgb

+from snowflake.ml.experiment import utils
+
 if TYPE_CHECKING:
     from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
     from snowflake.ml.model.model_signature import ModelSignature


 class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    """LightGBM callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",

@@ -20,12 +24,33 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new LightGBM callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
snowflake/ml/experiment/callback/xgboost.py
CHANGED

@@ -12,6 +12,8 @@ if TYPE_CHECKING:


 class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    """XGBoost callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",

@@ -23,12 +25,33 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Initialize the callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
snowflake/ml/experiment/experiment_tracking.py
CHANGED

@@ -1,6 +1,7 @@
 import functools
 import json
 import sys
+import warnings
 from typing import Any, Optional, Union
 from urllib.parse import quote

@@ -27,6 +28,13 @@ class ExperimentTracking:
     Class to manage experiments in Snowflake.
     """

+    _instance = None
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "ExperimentTracking":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
     def __init__(
         self,
         session: snowpark.Session,

@@ -36,6 +44,7 @@ class ExperimentTracking:
     ) -> None:
         """
         Initializes experiment tracking within a pre-created schema.
+        This is a singleton class, so if an instance already exists, it will not reinitialize.

         Args:
             session: The Snowpark Session to connect with Snowflake.

@@ -47,6 +56,21 @@ class ExperimentTracking:
         Raises:
             ValueError: If no database is provided and no active database exists in the session.
         """
+        if hasattr(self, "_initialized"):
+            warnings.warn(
+                "ExperimentTracking is a singleton class. Reusing the existing instance, which has the setting:\n"
+                f"  Database: {self._database_name}, Schema: {self._schema_name}\n"
+                "To change the database or schema, use the database_name and schema_name arguments to set_experiment.",
+                UserWarning,
+                stacklevel=2,
+            )
+            return
+
+        # Declare types for mypy
+        self._database_name: sql_identifier.SqlIdentifier
+        self._schema_name: sql_identifier.SqlIdentifier
+        self._sql_client: sql_client.ExperimentTrackingSQLClient
+
         if database_name:
             self._database_name = sql_identifier.SqlIdentifier(database_name)
         elif session_db := session.get_current_database():

@@ -78,6 +102,8 @@ class ExperimentTracking:
         # The run in context
         self._run: Optional[entities.Run] = None

+        self._initialized = True
+
     def __getstate__(self) -> dict[str, Any]:
         parent_state = (
             super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11

@@ -116,19 +142,40 @@ class ExperimentTracking:
     def set_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
    ) -> entities.Experiment:
         """
         Set the experiment in context. Creates a new experiment if it doesn't exist.

         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database. Defaults to None.
+            schema_name: The name of the schema. If None, the behavior depends on whether `database_name` is specified.
+                If `database_name` is specified, the schema is set to "PUBLIC".
+                If `database_name` is not specified, reuse the current schema. Defaults to None.

         Returns:
             Experiment: The experiment that was set.
         """
+        if database_name is not None:
+            if schema_name is None:
+                schema_name = "PUBLIC"
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
         experiment_name = sql_identifier.SqlIdentifier(experiment_name)
-        if
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             return self._experiment
+
+        self._update_database_and_schema(database_name, schema_name)
         self._sql_client.create_experiment(
             experiment_name=experiment_name,
             creation_mode=sql_client_utils.CreationMode(if_not_exists=True),

@@ -140,15 +187,42 @@ class ExperimentTracking:
     def delete_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
     ) -> None:
         """
         Delete an experiment.

         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database.
+                Must be specified if `schema_name` is specified. Defaults to None.
+            schema_name: The name of the schema. If None, reuse the current schema.
+                Must be specified if `database_name` is specified. Defaults to None.
+
+        Raises:
+            ValueError: If database_name is specified but schema_name is not.
         """
-
-
+        if (database_name is None) ^ (schema_name is None):  # if only one of database_name and schema_name is set
+            raise ValueError(
+                "If one of database_name and schema_name is specified, the other one must also be specified."
+            )
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
+        self._sql_client.drop_experiment(
+            database_name=database_name,
+            schema_name=schema_name,
+            experiment_name=sql_identifier.SqlIdentifier(experiment_name),
+        )
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             self._experiment = None
             self._run = None

@@ -451,6 +525,22 @@ class ExperimentTracking:
             return sql_identifier.SqlIdentifier(run_name)
         raise RuntimeError("Random run name generation failed.")

+    def _update_database_and_schema(
+        self, database_name: sql_identifier.SqlIdentifier, schema_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        self._database_name = database_name
+        self._schema_name = schema_name
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+        self._registry = registry.Registry(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+
     def _print_urls(
         self,
         experiment_name: sql_identifier.SqlIdentifier,
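Taken together, the singleton `__new__`, the `_initialized` guard, and `_update_database_and_schema` mean a second constructor call warns and reuses the first instance, while `set_experiment` and `delete_experiment` can now switch database/schema context. A hedged sketch, assuming `session` is an active Snowpark session and the constructor accepts these keyword arguments:

from snowflake.ml.experiment.experiment_tracking import ExperimentTracking

exp = ExperimentTracking(session, database_name="ML_DB", schema_name="EXPERIMENTS")
same = ExperimentTracking(session)  # warns: singleton, reuses exp's settings
assert same is exp

exp.set_experiment("MY_EXP", database_name="OTHER_DB")  # schema falls back to "PUBLIC"
exp.delete_experiment("MY_EXP", database_name="OTHER_DB", schema_name="PUBLIC")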
snowflake/ml/experiment/utils.py
CHANGED

@@ -1,3 +1,4 @@
+import numbers
 from typing import Any, Union


@@ -12,3 +13,8 @@ def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str
     else:
         flat_params[new_prefix] = value
     return flat_params
+
+
+def is_integer(value: Any) -> bool:
+    """Check if the given value is an integer, excluding booleans."""
+    return isinstance(value, numbers.Integral) and not isinstance(value, bool)
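One subtlety the new helper handles: `bool` is a subclass of `int` and therefore registers as `numbers.Integral`, so without the explicit exclusion `is_integer(True)` would return True. A quick check:

import numbers

print(isinstance(True, numbers.Integral))  # True, hence the explicit bool exclusion
print(isinstance(3, numbers.Integral) and not isinstance(3, bool))      # True
print(isinstance(True, numbers.Integral) and not isinstance(True, bool))  # False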