snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -0
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/telemetry.py +56 -7
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/_entities/run.py +15 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +123 -13
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/access_manager.py +1 -0
- snowflake/ml/feature_store/feature_store.py +1 -1
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/feature_flags.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +360 -357
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +2 -406
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +8 -9
- snowflake/ml/jobs/manager.py +64 -129
- snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
- snowflake/ml/model/_client/model/model_version_impl.py +109 -28
- snowflake/ml/model/_client/ops/model_ops.py +32 -6
- snowflake/ml/model/_client/ops/service_ops.py +9 -4
- snowflake/ml/model/_client/sql/service.py +69 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/core.py +305 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +25 -215
- snowflake/ml/model/type_hints.py +5 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/env_utils.py
@@ -16,6 +16,7 @@ from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
+from snowflake.snowpark._internal import utils as snowpark_utils
 
 
 class CONDA_OS(Enum):
@@ -38,6 +39,21 @@ SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
 SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
 
 
+def get_execution_context() -> str:
+    """Detect execution context: EXTERNAL, SPCS, or SPROC.
+
+    Returns:
+        str: The execution context - "SPROC" if running in a stored procedure,
+            "SPCS" if running in SPCS ML runtime, "EXTERNAL" otherwise.
+    """
+    if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        return "SPROC"
+    elif snowml_env.IN_ML_RUNTIME:
+        return "SPCS"
+    else:
+        return "EXTERNAL"
+
+
 def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
     """Validate the input pip requirement string according to PEP 508.
 
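
The new `get_execution_context()` helper gives telemetry a coarse label for where the SDK is running. A minimal usage sketch (assuming a regular client-side install, where neither the stored-procedure check nor the SPCS ML runtime flag triggers):

    from snowflake.ml._internal import env_utils

    # On a laptop or CI runner this is expected to return "EXTERNAL";
    # inside a stored procedure it returns "SPROC", and inside the
    # SPCS ML runtime it returns "SPCS".
    print(env_utils.get_execution_context())
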

snowflake/ml/_internal/platform_capabilities.py
@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
 INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
 SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
+ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
+FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
 
 
 class PlatformCapabilities:
@@ -80,6 +82,12 @@ class PlatformCapabilities:
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
 
+    def is_model_method_signature_parameters_enabled(self) -> bool:
+        return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)
+
+    def is_inference_autocapture_enabled(self) -> bool:
+        return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
         try:
@@ -182,3 +190,31 @@ class PlatformCapabilities:
             f"current={current_version}, feature={feature_version}, enabled={result}"
         )
         return result
+
+    def _is_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the feature parameter value belongs to enabled values.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            bool: True if the value is "ENABLED" or "ENABLED_PUBLIC_PREVIEW",
+                False if the value is "DISABLED", "DISABLED_PRIVATE_PREVIEW", or not set.
+
+        Raises:
+            ValueError: If the feature value is set but not one of the recognized values.
+        """
+        value = self.features.get(feature_name)
+        if value is None:
+            logger.debug(f"Feature {feature_name} not found.")
+            return False
+
+        if isinstance(value, str):
+            value_str = str(value)
+            if value_str.upper() in ["ENABLED", "ENABLED_PUBLIC_PREVIEW"]:
+                return True
+            elif value_str.upper() in ["DISABLED", "DISABLED_PRIVATE_PREVIEW"]:
+                return False
+            else:
+                raise ValueError(f"Invalid feature parameter value: {value} for feature {feature_name}")
+        raise ValueError(f"Invalid feature parameter string value: {value} for feature {feature_name}")

snowflake/ml/_internal/telemetry.py
@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
 from snowflake import connector
 from snowflake.connector import connect, telemetry as connector_telemetry, time_util
 from snowflake.ml import version as snowml_version
-from snowflake.ml._internal import env
+from snowflake.ml._internal import env, env_utils
 from snowflake.ml._internal.exceptions import (
     error_codes,
     exceptions as snowml_exceptions,
@@ -37,6 +37,22 @@ _CONNECTION_TYPES = {
 _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
+_conn: Optional[connector.SnowflakeConnection] = None
+
+
+def clear_cached_conn() -> None:
+    """Clear the cached Snowflake connection. Primarily for testing purposes."""
+    global _conn
+    if _conn is not None and _conn.is_valid():
+        _conn.close()
+    _conn = None
+
+
+def get_cached_conn() -> Optional[connector.SnowflakeConnection]:
+    """Get the cached Snowflake connection. Primarily for testing purposes."""
+    global _conn
+    return _conn
+
 
 def _get_login_token() -> Union[str, bytes]:
     with open("/snowflake/session/token") as f:
@@ -44,7 +60,11 @@ def _get_login_token() -> Union[str, bytes]:
 
 
 def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
-
+    global _conn
+    if _conn is not None and _conn.is_valid():
+        return _conn
+
+    conn: Optional[connector.SnowflakeConnection] = None
     if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
         try:
             conn = connect(
@@ -66,6 +86,13 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
         # Failed to get an active session. No connection available.
         pass
 
+    # cache the connection if it's a SnowflakeConnection. there is a behavior at runtime where it could be a
+    # StoredProcConnection perhaps incorrect type hinting somewhere
+    if isinstance(conn, connector.SnowflakeConnection):
+        # if _conn was expired, we need to copy telemetry data to new connection
+        if _conn is not None and conn is not None:
+            conn._telemetry._log_batch.extend(_conn._telemetry._log_batch)
+        _conn = conn
     return conn
 
 
@@ -113,6 +140,13 @@ class TelemetryField(enum.Enum):
     FUNC_CAT_USAGE = "usage"
 
 
+@enum.unique
+class CustomTagKey(enum.Enum):
+    """Keys for custom tags in telemetry."""
+
+    EXECUTION_CONTEXT = "execution_context"
+
+
 class _TelemetrySourceType(enum.Enum):
     # Automatically inferred telemetry/statement parameters
     AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY"
@@ -441,6 +475,7 @@ def send_api_usage_telemetry(
     sfqids_extractor: Optional[Callable[..., list[str]]] = None,
     subproject_extractor: Optional[Callable[[Any], str]] = None,
     custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
+    log_execution_context: bool = True,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
     """
     Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
@@ -455,6 +490,8 @@ def send_api_usage_telemetry(
         sfqids_extractor: Extract sfqids from `self`.
         subproject_extractor: Extract subproject at runtime from `self`.
         custom_tags: Custom tags.
+        log_execution_context: If True, automatically detect and log execution context
+            (EXTERNAL, SPCS, or SPROC) in custom_tags.
 
     Returns:
         Decorator that sends function usage telemetry for any call to the decorated function.
@@ -495,6 +532,11 @@ def send_api_usage_telemetry(
             if subproject_extractor is not None:
                 subproject_name = subproject_extractor(args[0])
 
+            # Add execution context if enabled
+            final_custom_tags = {**custom_tags} if custom_tags is not None else {}
+            if log_execution_context:
+                final_custom_tags[CustomTagKey.EXECUTION_CONTEXT.value] = env_utils.get_execution_context()
+
             statement_params = get_function_usage_statement_params(
                 project=project,
                 subproject=subproject_name,
@@ -502,7 +544,7 @@ def send_api_usage_telemetry(
                 function_name=_get_full_func_name(func),
                 function_parameters=params,
                 api_calls=api_calls,
-                custom_tags=custom_tags,
+                custom_tags=final_custom_tags,
             )
 
             def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
@@ -538,7 +580,10 @@ def send_api_usage_telemetry(
             if conn_attr_name:
                 # raise AttributeError if conn attribute does not exist in `self`
                 conn = operator.attrgetter(conn_attr_name)(args[0])
-                if not isinstance(conn, _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection)):
+                if not isinstance(
+                    conn,
+                    _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection),
+                ):
                     raise TypeError(
                         f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
                     )
@@ -560,7 +605,7 @@ def send_api_usage_telemetry(
                 func_params=params,
                 api_calls=api_calls,
                 sfqids=sfqids,
-                custom_tags=custom_tags,
+                custom_tags=final_custom_tags,
            )
             try:
                 return ctx.run(execute_func_with_statement_params)
@@ -571,7 +616,8 @@ def send_api_usage_telemetry(
                     raise
                 if isinstance(e, snowpark_exceptions.SnowparkClientException):
                     me = snowml_exceptions.SnowflakeMLException(
-                        error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
+                        error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
+                        original_exception=e,
                     )
                 else:
                     me = snowml_exceptions.SnowflakeMLException(
@@ -627,7 +673,10 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
 
 
 def _get_func_params(
-    func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
+    func: Callable[..., Any],
+    func_params_to_log: Optional[Iterable[str]],
+    args: Any,
+    kwargs: Any,
 ) -> dict[str, Any]:
     """
     Get function parameters.

snowflake/ml/data/_internal/arrow_ingestor.py
@@ -1,6 +1,8 @@
+import base64
 import collections
 import logging
 import os
+import re
 import time
 from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union
 
@@ -165,8 +167,71 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
         # Re-shuffle input files on each iteration start
         if shuffle:
             np.random.shuffle(sources)
-
-
+        try:
+            pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+            return pa_dataset
+        except Exception as e:
+            self._tmp_debug_parquet_invalid(e, sources)
+
+    def _tmp_debug_parquet_invalid(self, e: Exception, sources: list[Any]) -> None:
+        # Attach rich debug info to help diagnose intermittent Parquet footer/magic byte errors
+        debug_parts: list[str] = []
+        debug_parts.append("SNOWML DEBUG: Failed to construct Arrow Dataset")
+        debug_parts.append(
+            "SNOWML DEBUG: " f"data_sources_count={len(self._data_sources)} " f"resolved_sources_count={len(sources)}"
+        )
+        # Try to include the exact file path mentioned by pyarrow, if present
+        error_text = str(e)
+        snow_paths: list[str] = []
+        try:
+            # Extract snow://... tokens possibly wrapped in quotes
+            for match in re.finditer(r'(snow://[^\s\'"]+)', error_text):
+                token = match.group(1).rstrip(").,;]")
+                snow_paths.append(token)
+        except Exception:
+            pass
+        fs = self._kwargs.get("filesystem")
+        if fs is not None:
+            # Always include a directory listing with sizes for context
+            try:
+                debug_parts.append("SNOWML DEBUG: Listing resolved sources with sizes:")
+                for s in sources:
+                    try:
+                        info = fs.info(s)
+                        size = info.get("size", None)
+                        md5 = info.get("md5", None)
+                        debug_parts.append(f" - {s} size={size} md5={md5}")
+                    except Exception as le:
+                        debug_parts.append(f" - {s} info_failed={le}")
+            except Exception as le:
+                debug_parts.append(f"SNOWML DEBUG: listing sources failed: {le}")
+            # If pyarrow referenced a specific file, dump its full contents (base64) for inspection
+            for path in snow_paths[:1]:  # usually only one path appears in the message
+                try:
+                    info = fs.info(path)
+                    size = info.get("size", None)
+                    debug_parts.append(f"SNOWML DEBUG: Inspecting referenced file: {path} size={size}")
+                    with fs.open(path, "rb") as f:
+                        content = f.read()
+                    magic_head = content[:4]
+                    magic_tail = content[-4:] if content else b""
+                    looks_like_parquet = (magic_head == b"PAR1") and (magic_tail == b"PAR1")
+                    debug_parts.append(
+                        "SNOWML DEBUG: "
+                        f"file_magic_head={magic_head!r} "
+                        f"file_magic_tail={magic_tail!r} "
+                        f"parquet_magic_detected={looks_like_parquet}"
+                    )
+                    b64 = base64.b64encode(content).decode("ascii")
+                    debug_parts.append("SNOWML DEBUG: file_content_base64 (entire file):")
+                    debug_parts.append(b64)
+                except Exception as fe:
+                    debug_parts.append(f"SNOWML DEBUG: failed to read referenced file {path}: {fe}")
+        else:
+            debug_parts.append("SNOWML DEBUG: No filesystem available; cannot inspect files")
+        debug_message = "\n".join(debug_parts)
+        # Re-raise with augmented message to surface in stacktrace
+        raise RuntimeError(f"{e}\n{debug_message}") from e
 
     def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
         """Generate new batches from the existing record batch buffer."""
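
The debug helper's core signal is the Parquet magic marker: a well-formed file both starts and ends with the 4 bytes PAR1. A minimal local check along the same lines (a hypothetical helper operating on a local path rather than the snow:// filesystem used above; files shorter than 4 bytes are not handled):

    def looks_like_parquet(path: str) -> bool:
        # A valid Parquet file begins and ends with the magic bytes b"PAR1".
        with open(path, "rb") as f:
            head = f.read(4)
            f.seek(-4, 2)  # seek to the last four bytes (whence=2 means "from end")
            tail = f.read(4)
        return head == b"PAR1" and tail == b"PAR1"
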

snowflake/ml/data/data_connector.py
@@ -1,5 +1,15 @@
 import os
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    Literal,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -11,6 +21,7 @@ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
 from snowflake.snowpark import context as sp_context
 
 if TYPE_CHECKING:
+    import datasets as hf_datasets
     import pandas as pd
     import ray
     import tensorflow as tf
@@ -257,6 +268,97 @@ class DataConnector:
         except ImportError as e:
             raise ImportError("Ray is not installed, please install ray in your local environment.") from e
 
+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[False] = ...,
+        limit: Optional[int] = ...,
+    ) -> "hf_datasets.Dataset":
+        ...
+
+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[True],
+        limit: Optional[int] = ...,
+        batch_size: int = ...,
+        shuffle: bool = ...,
+        drop_last_batch: bool = ...,
+    ) -> "hf_datasets.IterableDataset":
+        ...
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["streaming", "limit", "batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: bool = False,
+        limit: Optional[int] = None,
+        batch_size: int = 1024,
+        shuffle: bool = False,
+        drop_last_batch: bool = False,
+    ) -> "Union[hf_datasets.Dataset, hf_datasets.IterableDataset]":
+        """Retrieve the Snowflake data as a HuggingFace Dataset.
+
+        Args:
+            streaming: If True, returns an IterableDataset that streams data in batches.
+                If False (default), returns an in-memory Dataset.
+            limit: Maximum number of rows to load. If None, loads all rows.
+            batch_size: Size of batches for internal data retrieval.
+            shuffle: Whether to shuffle the data. If True, files will be shuffled and rows
+                in each file will also be shuffled.
+            drop_last_batch: Whether to drop the last batch if it's smaller than batch_size.
+
+        Returns:
+            A HuggingFace Dataset (in-memory) or IterableDataset (streaming).
+        """
+        import datasets as hf_datasets
+
+        if streaming:
+            return self._to_huggingface_iterable_dataset(
+                limit=limit,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last_batch=drop_last_batch,
+            )
+        else:
+            return hf_datasets.Dataset.from_pandas(self._ingestor.to_pandas(limit))
+
+    def _to_huggingface_iterable_dataset(
+        self,
+        *,
+        limit: Optional[int],
+        batch_size: int,
+        shuffle: bool,
+        drop_last_batch: bool,
+    ) -> "hf_datasets.IterableDataset":
+        """Create a HuggingFace IterableDataset that streams data in batches."""
+        import datasets as hf_datasets
+
+        def generator() -> Generator[dict[str, Any], None, None]:
+            rows_yielded = 0
+            for batch in self._ingestor.to_batches(batch_size, shuffle, drop_last_batch):
+                # Yield individual rows from each batch
+                num_rows = len(next(iter(batch.values())))
+                for i in range(num_rows):
+                    if limit is not None and rows_yielded >= limit:
+                        return
+                    yield {k: v[i].item() if hasattr(v[i], "item") else v[i] for k, v in batch.items()}
+                    rows_yielded += 1
+
+        result = hf_datasets.IterableDataset.from_generator(generator)
+        # In datasets >= 3.x, from_generator returns IterableDatasetDict
+        # We need to extract the IterableDataset from the dict
+        if hasattr(hf_datasets, "IterableDatasetDict") and isinstance(result, hf_datasets.IterableDatasetDict):
+            # Get the first (and only) dataset from the dict
+            result = next(iter(result.values()))
+        return result
+
 
 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found

snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
@@ -41,8 +41,14 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
 
     @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
-    def drop_experiment(
-
+    def drop_experiment(
+        self,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        experiment_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(database_name, schema_name, experiment_name)
         query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
             expected_rows=1, expected_cols=1
         ).validate()

snowflake/ml/experiment/_entities/run.py
@@ -1,4 +1,5 @@
 import types
+import warnings
 from typing import TYPE_CHECKING, Optional
 
 from snowflake.ml._internal.utils import sql_identifier
@@ -7,6 +8,8 @@ from snowflake.ml.experiment import _experiment_info as experiment_info
 if TYPE_CHECKING:
     from snowflake.ml.experiment import experiment_tracking
 
+METADATA_SIZE_WARNING_MESSAGE = "It is likely that no further metrics or parameters will be logged for this run."
+
 
 class Run:
     def __init__(
@@ -20,6 +23,9 @@ class Run:
         self.experiment_name = experiment_name
         self.name = run_name
 
+        # Whether we've already shown the user a warning about exceeding the run metadata size limit.
+        self._warned_about_metadata_size = False
+
         self._patcher = experiment_info.ExperimentInfoPatcher(
             experiment_info=self._get_experiment_info(),
         )
@@ -45,3 +51,12 @@
             ),
             run_name=self.name.identifier(),
         )
+
+    def _warn_about_run_metadata_size(self, sql_error_msg: str) -> None:
+        if not self._warned_about_metadata_size:
+            warnings.warn(
+                f"{sql_error_msg}. {METADATA_SIZE_WARNING_MESSAGE}",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            self._warned_about_metadata_size = True

snowflake/ml/experiment/callback/keras.py
@@ -12,6 +12,8 @@ if TYPE_CHECKING:
 
 
 class SnowflakeKerasCallback(keras.callbacks.Callback):
+    """Keras callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new Keras callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every epoch.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name

snowflake/ml/experiment/callback/lightgbm.py
@@ -3,12 +3,16 @@ from warnings import warn
 
 import lightgbm as lgb
 
+from snowflake.ml.experiment import utils
+
 if TYPE_CHECKING:
     from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
     from snowflake.ml.model.model_signature import ModelSignature
 
 
 class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    """LightGBM callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -20,12 +24,33 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new LightGBM callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name

snowflake/ml/experiment/callback/xgboost.py
@@ -12,6 +12,8 @@ if TYPE_CHECKING:
 
 
 class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    """XGBoost callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Initialize the callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
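
And for XGBoost, where the callback subclasses `xgb.callback.TrainingCallback` and is passed to `xgb.train` (a sketch under the same assumptions as above):

    import xgboost as xgb
    from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback

    cb = SnowflakeXgboostCallback(exp, log_every_n_epochs=10, model_signature=sig)
    dtrain = xgb.DMatrix(x_train, label=y_train)
    booster = xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, "train")],
        callbacks=[cb],
    )
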