snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/env_utils.py +16 -0
  2. snowflake/ml/_internal/platform_capabilities.py +36 -0
  3. snowflake/ml/_internal/telemetry.py +56 -7
  4. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  5. snowflake/ml/data/data_connector.py +103 -1
  6. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  7. snowflake/ml/experiment/_entities/run.py +15 -0
  8. snowflake/ml/experiment/callback/keras.py +25 -2
  9. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  10. snowflake/ml/experiment/callback/xgboost.py +25 -2
  11. snowflake/ml/experiment/experiment_tracking.py +123 -13
  12. snowflake/ml/experiment/utils.py +6 -0
  13. snowflake/ml/feature_store/access_manager.py +1 -0
  14. snowflake/ml/feature_store/feature_store.py +1 -1
  15. snowflake/ml/feature_store/feature_view.py +34 -24
  16. snowflake/ml/jobs/_interop/protocols.py +3 -0
  17. snowflake/ml/jobs/_utils/feature_flags.py +1 -0
  18. snowflake/ml/jobs/_utils/payload_utils.py +360 -357
  19. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  20. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  21. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  22. snowflake/ml/jobs/_utils/spec_utils.py +2 -406
  23. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  24. snowflake/ml/jobs/_utils/types.py +14 -7
  25. snowflake/ml/jobs/job.py +8 -9
  26. snowflake/ml/jobs/manager.py +64 -129
  27. snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
  28. snowflake/ml/model/_client/model/model_version_impl.py +109 -28
  29. snowflake/ml/model/_client/ops/model_ops.py +32 -6
  30. snowflake/ml/model/_client/ops/service_ops.py +9 -4
  31. snowflake/ml/model/_client/sql/service.py +69 -2
  32. snowflake/ml/model/_packager/model_handler.py +8 -2
  33. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  34. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  35. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  36. snowflake/ml/model/_signatures/core.py +305 -8
  37. snowflake/ml/model/_signatures/utils.py +13 -4
  38. snowflake/ml/model/compute_pool.py +2 -0
  39. snowflake/ml/model/models/huggingface.py +285 -0
  40. snowflake/ml/model/models/huggingface_pipeline.py +25 -215
  41. snowflake/ml/model/type_hints.py +5 -1
  42. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  43. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  44. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  45. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  46. snowflake/ml/utils/html_utils.py +67 -1
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
  49. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
  50. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from snowflake.ml import version as snowml_version
  from snowflake.ml._internal import env as snowml_env, relax_version_strategy
  from snowflake.ml._internal.utils import query_result_checker
  from snowflake.snowpark import context, exceptions, session
+ from snowflake.snowpark._internal import utils as snowpark_utils


  class CONDA_OS(Enum):
@@ -38,6 +39,21 @@ SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
  SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"


+ def get_execution_context() -> str:
+     """Detect execution context: EXTERNAL, SPCS, or SPROC.
+
+     Returns:
+         str: The execution context - "SPROC" if running in a stored procedure,
+             "SPCS" if running in SPCS ML runtime, "EXTERNAL" otherwise.
+     """
+     if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+         return "SPROC"
+     elif snowml_env.IN_ML_RUNTIME:
+         return "SPCS"
+     else:
+         return "EXTERNAL"
+
+
  def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
      """Validate the input pip requirement string according to PEP 508.

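For reference, a minimal usage sketch of the new helper (the import path is taken from this diff; any session or runtime setup is assumed):

from snowflake.ml._internal import env_utils

# Returns "SPROC" inside a stored procedure, "SPCS" inside the SPCS ML runtime,
# and "EXTERNAL" anywhere else, per the docstring above.
print(env_utils.get_execution_context())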
@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
  LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
  INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
  SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
+ ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
+ FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"


  class PlatformCapabilities:
@@ -80,6 +82,12 @@ class PlatformCapabilities:
      def is_live_commit_enabled(self) -> bool:
          return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)

+     def is_model_method_signature_parameters_enabled(self) -> bool:
+         return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)
+
+     def is_inference_autocapture_enabled(self) -> bool:
+         return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
+
      @staticmethod
      def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
          try:
@@ -182,3 +190,31 @@ class PlatformCapabilities:
              f"current={current_version}, feature={feature_version}, enabled={result}"
          )
          return result
+
+     def _is_feature_enabled(self, feature_name: str) -> bool:
+         """Check if the feature parameter value belongs to enabled values.
+
+         Args:
+             feature_name: The name of the feature to retrieve.
+
+         Returns:
+             bool: True if the value is "ENABLED" or "ENABLED_PUBLIC_PREVIEW",
+                 False if the value is "DISABLED", "DISABLED_PRIVATE_PREVIEW", or not set.
+
+         Raises:
+             ValueError: If the feature value is set but not one of the recognized values.
+         """
+         value = self.features.get(feature_name)
+         if value is None:
+             logger.debug(f"Feature {feature_name} not found.")
+             return False
+
+         if isinstance(value, str):
+             value_str = str(value)
+             if value_str.upper() in ["ENABLED", "ENABLED_PUBLIC_PREVIEW"]:
+                 return True
+             elif value_str.upper() in ["DISABLED", "DISABLED_PRIVATE_PREVIEW"]:
+                 return False
+             else:
+                 raise ValueError(f"Invalid feature parameter value: {value} for feature {feature_name}")
+         raise ValueError(f"Invalid feature parameter string value: {value} for feature {feature_name}")
@@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
  from snowflake import connector
  from snowflake.connector import connect, telemetry as connector_telemetry, time_util
  from snowflake.ml import version as snowml_version
- from snowflake.ml._internal import env
+ from snowflake.ml._internal import env, env_utils
  from snowflake.ml._internal.exceptions import (
      error_codes,
      exceptions as snowml_exceptions,
@@ -37,6 +37,22 @@ _CONNECTION_TYPES = {
  _Args = ParamSpec("_Args")
  _ReturnValue = TypeVar("_ReturnValue")

+ _conn: Optional[connector.SnowflakeConnection] = None
+
+
+ def clear_cached_conn() -> None:
+     """Clear the cached Snowflake connection. Primarily for testing purposes."""
+     global _conn
+     if _conn is not None and _conn.is_valid():
+         _conn.close()
+     _conn = None
+
+
+ def get_cached_conn() -> Optional[connector.SnowflakeConnection]:
+     """Get the cached Snowflake connection. Primarily for testing purposes."""
+     global _conn
+     return _conn
+

  def _get_login_token() -> Union[str, bytes]:
      with open("/snowflake/session/token") as f:
@@ -44,7 +60,11 @@ def _get_login_token() -> Union[str, bytes]:


  def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
-     conn = None
+     global _conn
+     if _conn is not None and _conn.is_valid():
+         return _conn
+
+     conn: Optional[connector.SnowflakeConnection] = None
      if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
          try:
              conn = connect(
@@ -66,6 +86,13 @@ def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
              # Failed to get an active session. No connection available.
              pass

+     # Cache the connection if it is a SnowflakeConnection. At runtime it can also be a
+     # StoredProcConnection, perhaps due to incorrect type hinting somewhere.
+     if isinstance(conn, connector.SnowflakeConnection):
+         # If _conn was expired, copy its telemetry data over to the new connection.
+         if _conn is not None and conn is not None:
+             conn._telemetry._log_batch.extend(_conn._telemetry._log_batch)
+         _conn = conn
      return conn


@@ -113,6 +140,13 @@ class TelemetryField(enum.Enum):
      FUNC_CAT_USAGE = "usage"


+ @enum.unique
+ class CustomTagKey(enum.Enum):
+     """Keys for custom tags in telemetry."""
+
+     EXECUTION_CONTEXT = "execution_context"
+
+
  class _TelemetrySourceType(enum.Enum):
      # Automatically inferred telemetry/statement parameters
      AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY"
@@ -441,6 +475,7 @@ def send_api_usage_telemetry(
      sfqids_extractor: Optional[Callable[..., list[str]]] = None,
      subproject_extractor: Optional[Callable[[Any], str]] = None,
      custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
+     log_execution_context: bool = True,
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
      """
      Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
@@ -455,6 +490,8 @@ def send_api_usage_telemetry(
          sfqids_extractor: Extract sfqids from `self`.
          subproject_extractor: Extract subproject at runtime from `self`.
          custom_tags: Custom tags.
+         log_execution_context: If True, automatically detect and log execution context
+             (EXTERNAL, SPCS, or SPROC) in custom_tags.

      Returns:
          Decorator that sends function usage telemetry for any call to the decorated function.
@@ -495,6 +532,11 @@ def send_api_usage_telemetry(
          if subproject_extractor is not None:
              subproject_name = subproject_extractor(args[0])

+         # Add execution context if enabled
+         final_custom_tags = {**custom_tags} if custom_tags is not None else {}
+         if log_execution_context:
+             final_custom_tags[CustomTagKey.EXECUTION_CONTEXT.value] = env_utils.get_execution_context()
+
          statement_params = get_function_usage_statement_params(
              project=project,
              subproject=subproject_name,
@@ -502,7 +544,7 @@ def send_api_usage_telemetry(
              function_name=_get_full_func_name(func),
              function_parameters=params,
              api_calls=api_calls,
-             custom_tags=custom_tags,
+             custom_tags=final_custom_tags,
          )

          def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
@@ -538,7 +580,10 @@ def send_api_usage_telemetry(
          if conn_attr_name:
              # raise AttributeError if conn attribute does not exist in `self`
              conn = operator.attrgetter(conn_attr_name)(args[0])
-             if not isinstance(conn, _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection)):
+             if not isinstance(
+                 conn,
+                 _CONNECTION_TYPES.get(type(conn).__name__, connector.SnowflakeConnection),
+             ):
                  raise TypeError(
                      f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
                  )
@@ -560,7 +605,7 @@ def send_api_usage_telemetry(
              func_params=params,
              api_calls=api_calls,
              sfqids=sfqids,
-             custom_tags=custom_tags,
+             custom_tags=final_custom_tags,
          )
          try:
              return ctx.run(execute_func_with_statement_params)
@@ -571,7 +616,8 @@ def send_api_usage_telemetry(
                  raise
              if isinstance(e, snowpark_exceptions.SnowparkClientException):
                  me = snowml_exceptions.SnowflakeMLException(
-                     error_code=error_codes.INTERNAL_SNOWPARK_ERROR, original_exception=e
+                     error_code=error_codes.INTERNAL_SNOWPARK_ERROR,
+                     original_exception=e,
                  )
              else:
                  me = snowml_exceptions.SnowflakeMLException(
@@ -627,7 +673,10 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:


  def _get_func_params(
-     func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
+     func: Callable[..., Any],
+     func_params_to_log: Optional[Iterable[str]],
+     args: Any,
+     kwargs: Any,
  ) -> dict[str, Any]:
      """
      Get function parameters.
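A hedged usage sketch of the decorator with the new flag; the project and subproject strings are placeholders, and the default log_execution_context=True is shown explicitly:

from snowflake.ml._internal import telemetry

@telemetry.send_api_usage_telemetry(
    project="my_project",        # placeholder project name
    subproject="my_subproject",  # placeholder subproject name
    log_execution_context=True,  # default; adds an execution_context entry to custom_tags
)
def my_api_call(session):
    ...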
@@ -1,6 +1,8 @@
+ import base64
  import collections
  import logging
  import os
+ import re
  import time
  from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union

@@ -165,8 +167,71 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
          # Re-shuffle input files on each iteration start
          if shuffle:
              np.random.shuffle(sources)
-         pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
-         return pa_dataset
+         try:
+             pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+             return pa_dataset
+         except Exception as e:
+             self._tmp_debug_parquet_invalid(e, sources)
+
+     def _tmp_debug_parquet_invalid(self, e: Exception, sources: list[Any]) -> None:
+         # Attach rich debug info to help diagnose intermittent Parquet footer/magic byte errors
+         debug_parts: list[str] = []
+         debug_parts.append("SNOWML DEBUG: Failed to construct Arrow Dataset")
+         debug_parts.append(
+             "SNOWML DEBUG: " f"data_sources_count={len(self._data_sources)} " f"resolved_sources_count={len(sources)}"
+         )
+         # Try to include the exact file path mentioned by pyarrow, if present
+         error_text = str(e)
+         snow_paths: list[str] = []
+         try:
+             # Extract snow://... tokens possibly wrapped in quotes
+             for match in re.finditer(r'(snow://[^\s\'"]+)', error_text):
+                 token = match.group(1).rstrip(").,;]")
+                 snow_paths.append(token)
+         except Exception:
+             pass
+         fs = self._kwargs.get("filesystem")
+         if fs is not None:
+             # Always include a directory listing with sizes for context
+             try:
+                 debug_parts.append("SNOWML DEBUG: Listing resolved sources with sizes:")
+                 for s in sources:
+                     try:
+                         info = fs.info(s)
+                         size = info.get("size", None)
+                         md5 = info.get("md5", None)
+                         debug_parts.append(f" - {s} size={size} md5={md5}")
+                     except Exception as le:
+                         debug_parts.append(f" - {s} info_failed={le}")
+             except Exception as le:
+                 debug_parts.append(f"SNOWML DEBUG: listing sources failed: {le}")
+             # If pyarrow referenced a specific file, dump its full contents (base64) for inspection
+             for path in snow_paths[:1]:  # usually only one path appears in the message
+                 try:
+                     info = fs.info(path)
+                     size = info.get("size", None)
+                     debug_parts.append(f"SNOWML DEBUG: Inspecting referenced file: {path} size={size}")
+                     with fs.open(path, "rb") as f:
+                         content = f.read()
+                         magic_head = content[:4]
+                         magic_tail = content[-4:] if content else b""
+                         looks_like_parquet = (magic_head == b"PAR1") and (magic_tail == b"PAR1")
+                         debug_parts.append(
+                             "SNOWML DEBUG: "
+                             f"file_magic_head={magic_head!r} "
+                             f"file_magic_tail={magic_tail!r} "
+                             f"parquet_magic_detected={looks_like_parquet}"
+                         )
+                         b64 = base64.b64encode(content).decode("ascii")
+                         debug_parts.append("SNOWML DEBUG: file_content_base64 (entire file):")
+                         debug_parts.append(b64)
+                 except Exception as fe:
+                     debug_parts.append(f"SNOWML DEBUG: failed to read referenced file {path}: {fe}")
+         else:
+             debug_parts.append("SNOWML DEBUG: No filesystem available; cannot inspect files")
+         debug_message = "\n".join(debug_parts)
+         # Re-raise with augmented message to surface in stacktrace
+         raise RuntimeError(f"{e}\n{debug_message}") from e

      def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
          """Generate new batches from the existing record batch buffer."""
@@ -1,5 +1,15 @@
  import os
- from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Generator,
+     Literal,
+     Optional,
+     Sequence,
+     TypeVar,
+     Union,
+     overload,
+ )

  import numpy.typing as npt
  from typing_extensions import deprecated
@@ -11,6 +21,7 @@ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
  from snowflake.snowpark import context as sp_context

  if TYPE_CHECKING:
+     import datasets as hf_datasets
      import pandas as pd
      import ray
      import tensorflow as tf
@@ -257,6 +268,97 @@ class DataConnector:
          except ImportError as e:
              raise ImportError("Ray is not installed, please install ray in your local environment.") from e

+     @overload
+     def to_huggingface_dataset(
+         self,
+         *,
+         streaming: Literal[False] = ...,
+         limit: Optional[int] = ...,
+     ) -> "hf_datasets.Dataset":
+         ...
+
+     @overload
+     def to_huggingface_dataset(
+         self,
+         *,
+         streaming: Literal[True],
+         limit: Optional[int] = ...,
+         batch_size: int = ...,
+         shuffle: bool = ...,
+         drop_last_batch: bool = ...,
+     ) -> "hf_datasets.IterableDataset":
+         ...
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["streaming", "limit", "batch_size", "shuffle", "drop_last_batch"],
+     )
+     def to_huggingface_dataset(
+         self,
+         *,
+         streaming: bool = False,
+         limit: Optional[int] = None,
+         batch_size: int = 1024,
+         shuffle: bool = False,
+         drop_last_batch: bool = False,
+     ) -> "Union[hf_datasets.Dataset, hf_datasets.IterableDataset]":
+         """Retrieve the Snowflake data as a HuggingFace Dataset.
+
+         Args:
+             streaming: If True, returns an IterableDataset that streams data in batches.
+                 If False (default), returns an in-memory Dataset.
+             limit: Maximum number of rows to load. If None, loads all rows.
+             batch_size: Size of batches for internal data retrieval.
+             shuffle: Whether to shuffle the data. If True, files will be shuffled and rows
+                 in each file will also be shuffled.
+             drop_last_batch: Whether to drop the last batch if it's smaller than batch_size.
+
+         Returns:
+             A HuggingFace Dataset (in-memory) or IterableDataset (streaming).
+         """
+         import datasets as hf_datasets
+
+         if streaming:
+             return self._to_huggingface_iterable_dataset(
+                 limit=limit,
+                 batch_size=batch_size,
+                 shuffle=shuffle,
+                 drop_last_batch=drop_last_batch,
+             )
+         else:
+             return hf_datasets.Dataset.from_pandas(self._ingestor.to_pandas(limit))
+
+     def _to_huggingface_iterable_dataset(
+         self,
+         *,
+         limit: Optional[int],
+         batch_size: int,
+         shuffle: bool,
+         drop_last_batch: bool,
+     ) -> "hf_datasets.IterableDataset":
+         """Create a HuggingFace IterableDataset that streams data in batches."""
+         import datasets as hf_datasets
+
+         def generator() -> Generator[dict[str, Any], None, None]:
+             rows_yielded = 0
+             for batch in self._ingestor.to_batches(batch_size, shuffle, drop_last_batch):
+                 # Yield individual rows from each batch
+                 num_rows = len(next(iter(batch.values())))
+                 for i in range(num_rows):
+                     if limit is not None and rows_yielded >= limit:
+                         return
+                     yield {k: v[i].item() if hasattr(v[i], "item") else v[i] for k, v in batch.items()}
+                     rows_yielded += 1
+
+         result = hf_datasets.IterableDataset.from_generator(generator)
+         # In datasets >= 3.x, from_generator returns IterableDatasetDict
+         # We need to extract the IterableDataset from the dict
+         if hasattr(hf_datasets, "IterableDatasetDict") and isinstance(result, hf_datasets.IterableDatasetDict):
+             # Get the first (and only) dataset from the dict
+             result = next(iter(result.values()))
+         return result
+

  # Switch to use Runtime's Data Ingester if running in ML runtime
  # Fail silently if the data ingester is not found
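A hedged usage sketch of the new API; it assumes an existing Snowpark session and the DataConnector.from_dataframe constructor from earlier releases, and MY_TABLE is a placeholder:

from snowflake.ml.data.data_connector import DataConnector

dc = DataConnector.from_dataframe(session.table("MY_TABLE"))  # placeholder table

# In-memory HuggingFace Dataset, capped at 10,000 rows
ds = dc.to_huggingface_dataset(limit=10_000)

# Streaming IterableDataset that shuffles files and rows on each pass
ids = dc.to_huggingface_dataset(streaming=True, batch_size=512, shuffle=True)
for example in ids.take(3):
    print(example)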
@@ -41,8 +41,14 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
          ).has_dimensions(expected_rows=1, expected_cols=1).validate()

      @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
-     def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
-         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+     def drop_experiment(
+         self,
+         *,
+         database_name: sql_identifier.SqlIdentifier,
+         schema_name: sql_identifier.SqlIdentifier,
+         experiment_name: sql_identifier.SqlIdentifier,
+     ) -> None:
+         experiment_fqn = self.fully_qualified_object_name(database_name, schema_name, experiment_name)
          query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
              expected_rows=1, expected_cols=1
          ).validate()
@@ -1,4 +1,5 @@
  import types
+ import warnings
  from typing import TYPE_CHECKING, Optional

  from snowflake.ml._internal.utils import sql_identifier
@@ -7,6 +8,8 @@ from snowflake.ml.experiment import _experiment_info as experiment_info
  if TYPE_CHECKING:
      from snowflake.ml.experiment import experiment_tracking

+ METADATA_SIZE_WARNING_MESSAGE = "It is likely that no further metrics or parameters will be logged for this run."
+


  class Run:
@@ -20,6 +23,9 @@ class Run:
          self.experiment_name = experiment_name
          self.name = run_name

+         # Whether we've already shown the user a warning about exceeding the run metadata size limit.
+         self._warned_about_metadata_size = False
+
          self._patcher = experiment_info.ExperimentInfoPatcher(
              experiment_info=self._get_experiment_info(),
          )
@@ -45,3 +51,12 @@ class Run:
              ),
              run_name=self.name.identifier(),
          )
+
+     def _warn_about_run_metadata_size(self, sql_error_msg: str) -> None:
+         if not self._warned_about_metadata_size:
+             warnings.warn(
+                 f"{sql_error_msg}. {METADATA_SIZE_WARNING_MESSAGE}",
+                 RuntimeWarning,
+                 stacklevel=2,
+             )
+             self._warned_about_metadata_size = True
@@ -12,6 +12,8 @@ if TYPE_CHECKING:


  class SnowflakeKerasCallback(keras.callbacks.Callback):
+     """Keras callback for automatically logging to a Snowflake ML Experiment."""
+
      def __init__(
          self,
          experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
          version_name: Optional[str] = None,
          model_signature: Optional["ModelSignature"] = None,
      ) -> None:
+         """
+         Creates a new Keras callback.
+
+         Args:
+             experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                 to use for logging.
+             log_model (bool): Whether to log the model at the end of training. Default is True.
+             log_metrics (bool): Whether to log metrics during training. Default is True.
+             log_params (bool): Whether to log model parameters at the start of training. Default is True.
+             log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                 Default is 1, logging after every epoch.
+             model_name (Optional[str]): The model name to use when logging the model.
+                 If None, the model name will be derived from the experiment name.
+             version_name (Optional[str]): The model version name to use when logging the model.
+                 If None, the version name will be randomly generated.
+             model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                 when logging the model. This is required if ``log_model`` is set to True.
+
+         Raises:
+             ValueError: When ``log_every_n_epochs`` is not a positive integer.
+         """
          self._experiment_tracking = experiment_tracking
          self.log_model = log_model
          self.log_metrics = log_metrics
          self.log_params = log_params
-         if log_every_n_epochs < 1:
-             raise ValueError("`log_every_n_epochs` must be positive.")
+         if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+             raise ValueError("`log_every_n_epochs` must be a positive integer.")
          self.log_every_n_epochs = log_every_n_epochs
          self.model_name = model_name
          self.version_name = version_name
@@ -3,12 +3,16 @@ from warnings import warn

  import lightgbm as lgb

+ from snowflake.ml.experiment import utils
+
  if TYPE_CHECKING:
      from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
      from snowflake.ml.model.model_signature import ModelSignature


  class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+     """LightGBM callback for automatically logging to a Snowflake ML Experiment."""
+
      def __init__(
          self,
          experiment_tracking: "ExperimentTracking",
@@ -20,12 +24,33 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
          version_name: Optional[str] = None,
          model_signature: Optional["ModelSignature"] = None,
      ) -> None:
+         """
+         Creates a new LightGBM callback.
+
+         Args:
+             experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                 to use for logging.
+             log_model (bool): Whether to log the model at the end of training. Default is True.
+             log_metrics (bool): Whether to log metrics during training. Default is True.
+             log_params (bool): Whether to log model parameters at the start of training. Default is True.
+             log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                 Default is 1, logging after every iteration.
+             model_name (Optional[str]): The model name to use when logging the model.
+                 If None, the model name will be derived from the experiment name.
+             version_name (Optional[str]): The model version name to use when logging the model.
+                 If None, the version name will be randomly generated.
+             model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                 when logging the model. This is required if ``log_model`` is set to True.
+
+         Raises:
+             ValueError: When ``log_every_n_epochs`` is not a positive integer.
+         """
          self._experiment_tracking = experiment_tracking
          self.log_model = log_model
          self.log_metrics = log_metrics
          self.log_params = log_params
-         if log_every_n_epochs < 1:
-             raise ValueError("`log_every_n_epochs` must be positive.")
+         if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+             raise ValueError("`log_every_n_epochs` must be a positive integer.")
          self.log_every_n_epochs = log_every_n_epochs
          self.model_name = model_name
          self.version_name = version_name
@@ -12,6 +12,8 @@ if TYPE_CHECKING:


  class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+     """XGBoost callback for automatically logging to a Snowflake ML Experiment."""
+
      def __init__(
          self,
          experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
          version_name: Optional[str] = None,
          model_signature: Optional["ModelSignature"] = None,
      ) -> None:
+         """
+         Initialize the callback.
+
+         Args:
+             experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                 to use for logging.
+             log_model (bool): Whether to log the model at the end of training. Default is True.
+             log_metrics (bool): Whether to log metrics during training. Default is True.
+             log_params (bool): Whether to log model parameters at the start of training. Default is True.
+             log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                 Default is 1, logging after every iteration.
+             model_name (Optional[str]): The model name to use when logging the model.
+                 If None, the model name will be derived from the experiment name.
+             version_name (Optional[str]): The model version name to use when logging the model.
+                 If None, the version name will be randomly generated.
+             model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                 when logging the model. This is required if ``log_model`` is set to True.
+
+         Raises:
+             ValueError: When ``log_every_n_epochs`` is not a positive integer.
+         """
          self._experiment_tracking = experiment_tracking
          self.log_model = log_model
          self.log_metrics = log_metrics
          self.log_params = log_params
-         if log_every_n_epochs < 1:
-             raise ValueError("`log_every_n_epochs` must be positive.")
+         if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+             raise ValueError("`log_every_n_epochs` must be a positive integer.")
          self.log_every_n_epochs = log_every_n_epochs
          self.model_name = model_name
          self.version_name = version_name
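A hedged end-to-end sketch of wiring the XGBoost callback into training; exp is assumed to be an ExperimentTracking instance whose start_run() behaves as a context manager (MLflow-style), and params, dtrain, dvalid, X_sample, and y_sample are placeholder training inputs:

import xgboost as xgb

from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback
from snowflake.ml.model import model_signature

sig = model_signature.infer_signature(X_sample, y_sample)  # placeholder sample data
callback = SnowflakeXgboostCallback(
    exp,                    # assumed ExperimentTracking instance
    log_model=True,
    log_metrics=True,
    log_params=True,
    log_every_n_epochs=1,   # must be a positive integer, per the validation above
    model_name="MY_MODEL",  # placeholder
    model_signature=sig,    # required because log_model=True
)
with exp.start_run():       # assumed context-manager usage
    xgb.train(params, dtrain, evals=[(dvalid, "validation")], callbacks=[callback])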