snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/platform_capabilities.py
@@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
 INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
 SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
+ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
+FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"


 class PlatformCapabilities:
@@ -80,6 +82,12 @@ class PlatformCapabilities:
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)

+    def is_model_method_signature_parameters_enabled(self) -> bool:
+        return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)
+
+    def is_inference_autocapture_enabled(self) -> bool:
+        return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
+
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
         try:
@@ -182,3 +190,31 @@ class PlatformCapabilities:
             f"current={current_version}, feature={feature_version}, enabled={result}"
         )
         return result
+
+    def _is_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the feature parameter value belongs to enabled values.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            bool: True if the value is "ENABLED" or "ENABLED_PUBLIC_PREVIEW",
+                False if the value is "DISABLED", "DISABLED_PRIVATE_PREVIEW", or not set.
+
+        Raises:
+            ValueError: If the feature value is set but not one of the recognized values.
+        """
+        value = self.features.get(feature_name)
+        if value is None:
+            logger.debug(f"Feature {feature_name} not found.")
+            return False
+
+        if isinstance(value, str):
+            value_str = str(value)
+            if value_str.upper() in ["ENABLED", "ENABLED_PUBLIC_PREVIEW"]:
+                return True
+            elif value_str.upper() in ["DISABLED", "DISABLED_PRIVATE_PREVIEW"]:
+                return False
+            else:
+                raise ValueError(f"Invalid feature parameter value: {value} for feature {feature_name}")
+        raise ValueError(f"Invalid feature parameter string value: {value} for feature {feature_name}")
snowflake/ml/_internal/utils/url.py (new file)
@@ -0,0 +1,42 @@
+from urllib.parse import urlunparse
+
+from snowflake import snowpark as snowpark
+
+JOB_URL_PREFIX = "#/compute/job/"
+SERVICE_URL_PREFIX = "#/compute/service/"
+
+
+def get_snowflake_url(
+    session: snowpark.Session,
+    url_path: str,
+    params: str = "",
+    query: str = "",
+    fragment: str = "",
+) -> str:
+    """Construct a Snowflake URL from session connection details and URL components.
+
+    Args:
+        session: The Snowpark session containing connection details.
+        url_path: The path component of the URL (e.g., "/compute/job/123").
+        params: Optional parameters for the URL (RFC 1808). Defaults to "".
+        query: Optional query string for the URL. Defaults to "".
+        fragment: Optional fragment identifier for the URL (e.g., "#section"). Defaults to "".
+
+    Returns:
+        A fully constructed Snowflake URL string with scheme, host, and specified components.
+    """
+    scheme = "https"
+    if hasattr(session.connection, "scheme"):
+        scheme = session.connection.scheme
+    host = session.connection.host
+
+    return urlunparse(
+        (
+            scheme,
+            host,
+            url_path,
+            params,
+            query,
+            fragment,
+        )
+    )
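A hedged usage sketch of the new helper, assuming an existing `session`. Note that `JOB_URL_PREFIX` already carries the leading `#`, which `urlunparse` would duplicate if passed verbatim as the fragment, so the sketch strips it (an assumption about how callers combine the constants; the job ID is illustrative):

```python
from snowflake.ml._internal.utils import url as url_utils

job_url = url_utils.get_snowflake_url(
    session,
    url_path="",
    fragment=url_utils.JOB_URL_PREFIX.lstrip("#") + "01a2b3c4-0000",
)
# e.g. https://<account>.snowflakecomputing.com#/compute/job/01a2b3c4-0000
```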
snowflake/ml/data/_internal/arrow_ingestor.py
@@ -1,6 +1,8 @@
+import base64
 import collections
 import logging
 import os
+import re
 import time
 from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union

@@ -165,8 +167,71 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
         # Re-shuffle input files on each iteration start
         if shuffle:
             np.random.shuffle(sources)
-        pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
-        return pa_dataset
+        try:
+            pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+            return pa_dataset
+        except Exception as e:
+            self._tmp_debug_parquet_invalid(e, sources)
+
+    def _tmp_debug_parquet_invalid(self, e: Exception, sources: list[Any]) -> None:
+        # Attach rich debug info to help diagnose intermittent Parquet footer/magic byte errors
+        debug_parts: list[str] = []
+        debug_parts.append("SNOWML DEBUG: Failed to construct Arrow Dataset")
+        debug_parts.append(
+            "SNOWML DEBUG: " f"data_sources_count={len(self._data_sources)} " f"resolved_sources_count={len(sources)}"
+        )
+        # Try to include the exact file path mentioned by pyarrow, if present
+        error_text = str(e)
+        snow_paths: list[str] = []
+        try:
+            # Extract snow://... tokens possibly wrapped in quotes
+            for match in re.finditer(r'(snow://[^\s\'"]+)', error_text):
+                token = match.group(1).rstrip(").,;]")
+                snow_paths.append(token)
+        except Exception:
+            pass
+        fs = self._kwargs.get("filesystem")
+        if fs is not None:
+            # Always include a directory listing with sizes for context
+            try:
+                debug_parts.append("SNOWML DEBUG: Listing resolved sources with sizes:")
+                for s in sources:
+                    try:
+                        info = fs.info(s)
+                        size = info.get("size", None)
+                        md5 = info.get("md5", None)
+                        debug_parts.append(f"  - {s} size={size} md5={md5}")
+                    except Exception as le:
+                        debug_parts.append(f"  - {s} info_failed={le}")
+            except Exception as le:
+                debug_parts.append(f"SNOWML DEBUG: listing sources failed: {le}")
+            # If pyarrow referenced a specific file, dump its full contents (base64) for inspection
+            for path in snow_paths[:1]:  # usually only one path appears in the message
+                try:
+                    info = fs.info(path)
+                    size = info.get("size", None)
+                    debug_parts.append(f"SNOWML DEBUG: Inspecting referenced file: {path} size={size}")
+                    with fs.open(path, "rb") as f:
+                        content = f.read()
+                    magic_head = content[:4]
+                    magic_tail = content[-4:] if content else b""
+                    looks_like_parquet = (magic_head == b"PAR1") and (magic_tail == b"PAR1")
+                    debug_parts.append(
+                        "SNOWML DEBUG: "
+                        f"file_magic_head={magic_head!r} "
+                        f"file_magic_tail={magic_tail!r} "
+                        f"parquet_magic_detected={looks_like_parquet}"
+                    )
+                    b64 = base64.b64encode(content).decode("ascii")
+                    debug_parts.append("SNOWML DEBUG: file_content_base64 (entire file):")
+                    debug_parts.append(b64)
+                except Exception as fe:
+                    debug_parts.append(f"SNOWML DEBUG: failed to read referenced file {path}: {fe}")
+        else:
+            debug_parts.append("SNOWML DEBUG: No filesystem available; cannot inspect files")
+        debug_message = "\n".join(debug_parts)
+        # Re-raise with augmented message to surface in stacktrace
+        raise RuntimeError(f"{e}\n{debug_message}") from e

     def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
         """Generate new batches from the existing record batch buffer."""
snowflake/ml/data/data_connector.py
@@ -1,5 +1,15 @@
 import os
-from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    Literal,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    overload,
+)

 import numpy.typing as npt
 from typing_extensions import deprecated
@@ -11,6 +21,7 @@ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
 from snowflake.snowpark import context as sp_context

 if TYPE_CHECKING:
+    import datasets as hf_datasets
     import pandas as pd
     import ray
     import tensorflow as tf
@@ -257,6 +268,97 @@ class DataConnector:
         except ImportError as e:
             raise ImportError("Ray is not installed, please install ray in your local environment.") from e

+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[False] = ...,
+        limit: Optional[int] = ...,
+    ) -> "hf_datasets.Dataset":
+        ...
+
+    @overload
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: Literal[True],
+        limit: Optional[int] = ...,
+        batch_size: int = ...,
+        shuffle: bool = ...,
+        drop_last_batch: bool = ...,
+    ) -> "hf_datasets.IterableDataset":
+        ...
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["streaming", "limit", "batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_huggingface_dataset(
+        self,
+        *,
+        streaming: bool = False,
+        limit: Optional[int] = None,
+        batch_size: int = 1024,
+        shuffle: bool = False,
+        drop_last_batch: bool = False,
+    ) -> "Union[hf_datasets.Dataset, hf_datasets.IterableDataset]":
+        """Retrieve the Snowflake data as a HuggingFace Dataset.
+
+        Args:
+            streaming: If True, returns an IterableDataset that streams data in batches.
+                If False (default), returns an in-memory Dataset.
+            limit: Maximum number of rows to load. If None, loads all rows.
+            batch_size: Size of batches for internal data retrieval.
+            shuffle: Whether to shuffle the data. If True, files will be shuffled and rows
+                in each file will also be shuffled.
+            drop_last_batch: Whether to drop the last batch if it's smaller than batch_size.
+
+        Returns:
+            A HuggingFace Dataset (in-memory) or IterableDataset (streaming).
+        """
+        import datasets as hf_datasets
+
+        if streaming:
+            return self._to_huggingface_iterable_dataset(
+                limit=limit,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last_batch=drop_last_batch,
+            )
+        else:
+            return hf_datasets.Dataset.from_pandas(self._ingestor.to_pandas(limit))
+
+    def _to_huggingface_iterable_dataset(
+        self,
+        *,
+        limit: Optional[int],
+        batch_size: int,
+        shuffle: bool,
+        drop_last_batch: bool,
+    ) -> "hf_datasets.IterableDataset":
+        """Create a HuggingFace IterableDataset that streams data in batches."""
+        import datasets as hf_datasets
+
+        def generator() -> Generator[dict[str, Any], None, None]:
+            rows_yielded = 0
+            for batch in self._ingestor.to_batches(batch_size, shuffle, drop_last_batch):
+                # Yield individual rows from each batch
+                num_rows = len(next(iter(batch.values())))
+                for i in range(num_rows):
+                    if limit is not None and rows_yielded >= limit:
+                        return
+                    yield {k: v[i].item() if hasattr(v[i], "item") else v[i] for k, v in batch.items()}
+                    rows_yielded += 1
+
+        result = hf_datasets.IterableDataset.from_generator(generator)
+        # In datasets >= 3.x, from_generator returns IterableDatasetDict
+        # We need to extract the IterableDataset from the dict
+        if hasattr(hf_datasets, "IterableDatasetDict") and isinstance(result, hf_datasets.IterableDatasetDict):
+            # Get the first (and only) dataset from the dict
+            result = next(iter(result.values()))
+        return result
+

 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
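A hedged usage sketch of the new API, assuming an existing `session` and the `datasets` package installed; the table name is illustrative:

```python
from snowflake.ml.data.data_connector import DataConnector

dc = DataConnector.from_dataframe(session.table("MY_TRAINING_DATA"))

ds = dc.to_huggingface_dataset()            # in-memory datasets.Dataset
stream = dc.to_huggingface_dataset(         # streaming IterableDataset
    streaming=True, batch_size=512, shuffle=True
)
for row in stream.take(3):                  # peek at the first few rows
    print(row)
```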
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
@@ -41,8 +41,14 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

     @telemetry.send_api_usage_telemetry(project=telemetry.TelemetryProject.EXPERIMENT_TRACKING.value)
-    def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
-        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+    def drop_experiment(
+        self,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        experiment_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(database_name, schema_name, experiment_name)
         query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
             expected_rows=1, expected_cols=1
         ).validate()
snowflake/ml/experiment/callback/__init__.py
File without changes
snowflake/ml/experiment/callback/keras.py
@@ -12,6 +12,8 @@ if TYPE_CHECKING:


 class SnowflakeKerasCallback(keras.callbacks.Callback):
+    """Keras callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeKerasCallback(keras.callbacks.Callback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new Keras callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every epoch.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs < 1:
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
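A hedged usage sketch: `exp` is an `ExperimentTracking` instance and the Keras model and training arrays are assumed to exist; with `log_model=False`, no model signature is required:

```python
from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback

callback = SnowflakeKerasCallback(exp, log_model=False, log_every_n_epochs=5)
model.fit(x_train, y_train, epochs=20, callbacks=[callback])
```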
snowflake/ml/experiment/callback/lightgbm.py
@@ -3,12 +3,16 @@ from warnings import warn

 import lightgbm as lgb

+from snowflake.ml.experiment import utils
+
 if TYPE_CHECKING:
     from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
     from snowflake.ml.model.model_signature import ModelSignature


 class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    """LightGBM callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -20,12 +24,33 @@ class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Creates a new LightGBM callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs < 1:
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
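A hedged sketch for LightGBM, where callbacks are passed to `lgb.train`; data names are illustrative and `exp` is an `ExperimentTracking` instance:

```python
import lightgbm as lgb
from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback

cb = SnowflakeLightgbmCallback(exp, log_model=False)
booster = lgb.train(
    {"objective": "binary"},
    lgb.Dataset(x_train, label=y_train),
    valid_sets=[lgb.Dataset(x_valid, label=y_valid)],
    callbacks=[cb],
)
```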
snowflake/ml/experiment/callback/xgboost.py
@@ -12,6 +12,8 @@ if TYPE_CHECKING:


 class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    """XGBoost callback for automatically logging to a Snowflake ML Experiment."""
+
     def __init__(
         self,
         experiment_tracking: "ExperimentTracking",
@@ -23,12 +25,33 @@ class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
         version_name: Optional[str] = None,
         model_signature: Optional["ModelSignature"] = None,
     ) -> None:
+        """
+        Initialize the callback.
+
+        Args:
+            experiment_tracking (snowflake.ml.experiment.ExperimentTracking): The Experiment Tracking instance
+                to use for logging.
+            log_model (bool): Whether to log the model at the end of training. Default is True.
+            log_metrics (bool): Whether to log metrics during training. Default is True.
+            log_params (bool): Whether to log model parameters at the start of training. Default is True.
+            log_every_n_epochs (int): Frequency with which to log metrics. Must be positive.
+                Default is 1, logging after every iteration.
+            model_name (Optional[str]): The model name to use when logging the model.
+                If None, the model name will be derived from the experiment name.
+            version_name (Optional[str]): The model version name to use when logging the model.
+                If None, the version name will be randomly generated.
+            model_signature (Optional[snowflake.ml.model.model_signature.ModelSignature]): The model signature to use
+                when logging the model. This is required if ``log_model`` is set to True.
+
+        Raises:
+            ValueError: When ``log_every_n_epochs`` is not a positive integer.
+        """
         self._experiment_tracking = experiment_tracking
         self.log_model = log_model
         self.log_metrics = log_metrics
         self.log_params = log_params
-        if log_every_n_epochs < 1:
-            raise ValueError("`log_every_n_epochs` must be positive.")
+        if not (utils.is_integer(log_every_n_epochs) and log_every_n_epochs > 0):
+            raise ValueError("`log_every_n_epochs` must be a positive integer.")
         self.log_every_n_epochs = log_every_n_epochs
         self.model_name = model_name
         self.version_name = version_name
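A hedged sketch for XGBoost, where `TrainingCallback` instances go in `xgb.train`'s callbacks list; data names are illustrative and `exp` is an `ExperimentTracking` instance:

```python
import xgboost as xgb
from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback

cb = SnowflakeXgboostCallback(exp, log_model=False, log_every_n_epochs=10)
booster = xgb.train(
    {"objective": "reg:squarederror"},
    xgb.DMatrix(x_train, label=y_train),
    num_boost_round=100,
    evals=[(xgb.DMatrix(x_valid, label=y_valid), "valid")],
    callbacks=[cb],
)
```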
snowflake/ml/experiment/experiment_tracking.py
@@ -1,6 +1,7 @@
 import functools
 import json
 import sys
+import warnings
 from typing import Any, Optional, Union
 from urllib.parse import quote

@@ -27,6 +28,13 @@ class ExperimentTracking:
     Class to manage experiments in Snowflake.
     """

+    _instance = None
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> "ExperimentTracking":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
     def __init__(
         self,
         session: snowpark.Session,
@@ -36,6 +44,7 @@ class ExperimentTracking:
     ) -> None:
         """
         Initializes experiment tracking within a pre-created schema.
+        This is a singleton class, so if an instance already exists, it will not reinitialize.

         Args:
             session: The Snowpark Session to connect with Snowflake.
@@ -47,6 +56,21 @@ class ExperimentTracking:
         Raises:
             ValueError: If no database is provided and no active database exists in the session.
         """
+        if hasattr(self, "_initialized"):
+            warnings.warn(
+                "ExperimentTracking is a singleton class. Reusing the existing instance, which has the setting:\n"
+                f"    Database: {self._database_name}, Schema: {self._schema_name}\n"
+                "To change the database or schema, use the database_name and schema_name arguments to set_experiment.",
+                UserWarning,
+                stacklevel=2,
+            )
+            return
+
+        # Declare types for mypy
+        self._database_name: sql_identifier.SqlIdentifier
+        self._schema_name: sql_identifier.SqlIdentifier
+        self._sql_client: sql_client.ExperimentTrackingSQLClient
+
         if database_name:
             self._database_name = sql_identifier.SqlIdentifier(database_name)
         elif session_db := session.get_current_database():
@@ -78,6 +102,8 @@ class ExperimentTracking:
         # The run in context
         self._run: Optional[entities.Run] = None

+        self._initialized = True
+
     def __getstate__(self) -> dict[str, Any]:
         parent_state = (
             super().__getstate__()  # type: ignore[misc]  # object.__getstate__ appears in 3.11
@@ -116,19 +142,40 @@ class ExperimentTracking:
     def set_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
     ) -> entities.Experiment:
         """
         Set the experiment in context. Creates a new experiment if it doesn't exist.

         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database. Defaults to None.
+            schema_name: The name of the schema. If None, the behavior depends on whether `database_name` is specified.
+                If `database_name` is specified, the schema is set to "PUBLIC".
+                If `database_name` is not specified, reuse the current schema. Defaults to None.

         Returns:
             Experiment: The experiment that was set.
         """
+        if database_name is not None:
+            if schema_name is None:
+                schema_name = "PUBLIC"
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
         experiment_name = sql_identifier.SqlIdentifier(experiment_name)
-        if self._experiment and self._experiment.name == experiment_name:
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             return self._experiment
+
+        self._update_database_and_schema(database_name, schema_name)
         self._sql_client.create_experiment(
             experiment_name=experiment_name,
             creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
@@ -140,15 +187,42 @@ class ExperimentTracking:
     def delete_experiment(
         self,
         experiment_name: str,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
     ) -> None:
         """
         Delete an experiment.

         Args:
             experiment_name: The name of the experiment.
+            database_name: The name of the database. If None, reuse the current database.
+                Must be specified if `schema_name` is specified. Defaults to None.
+            schema_name: The name of the schema. If None, reuse the current schema.
+                Must be specified if `database_name` is specified. Defaults to None.
+
+        Raises:
+            ValueError: If database_name is specified but schema_name is not.
         """
-        self._sql_client.drop_experiment(experiment_name=sql_identifier.SqlIdentifier(experiment_name))
-        if self._experiment and self._experiment.name == experiment_name:
+        if (database_name is None) ^ (schema_name is None):  # if only one of database_name and schema_name is set
+            raise ValueError(
+                "If one of database_name and schema_name is specified, the other one must also be specified."
+            )
+        database_name = (
+            sql_identifier.SqlIdentifier(database_name) if database_name is not None else self._database_name
+        )
+        schema_name = sql_identifier.SqlIdentifier(schema_name) if schema_name is not None else self._schema_name
+
+        self._sql_client.drop_experiment(
+            database_name=database_name,
+            schema_name=schema_name,
+            experiment_name=sql_identifier.SqlIdentifier(experiment_name),
+        )
+        if (
+            self._experiment
+            and self._experiment.name == experiment_name
+            and self._database_name == database_name
+            and self._schema_name == schema_name
+        ):
             self._experiment = None
             self._run = None

@@ -451,6 +525,22 @@ class ExperimentTracking:
             return sql_identifier.SqlIdentifier(run_name)
         raise RuntimeError("Random run name generation failed.")

+    def _update_database_and_schema(
+        self, database_name: sql_identifier.SqlIdentifier, schema_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        self._database_name = database_name
+        self._schema_name = schema_name
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+        self._registry = registry.Registry(
+            session=self._session,
+            database_name=database_name,
+            schema_name=schema_name,
+        )
+
     def _print_urls(
         self,
         experiment_name: sql_identifier.SqlIdentifier,
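A hedged sketch of the new singleton and scoping behavior introduced above, assuming an existing `session`; database, schema, and experiment names are illustrative:

```python
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking

exp = ExperimentTracking(session, database_name="ML", schema_name="PUBLIC")
same = ExperimentTracking(session)  # warns and returns the existing instance
assert exp is same

# set_experiment can now retarget the database/schema; with only
# database_name given, the schema defaults to "PUBLIC".
exp.set_experiment("PRICING_V2", database_name="ANALYTICS")

# delete_experiment requires both names or neither.
exp.delete_experiment("PRICING_V2", database_name="ANALYTICS", schema_name="PUBLIC")
```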
snowflake/ml/experiment/utils.py
@@ -1,3 +1,4 @@
+import numbers
 from typing import Any, Union


@@ -12,3 +13,8 @@ def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str
     else:
         flat_params[new_prefix] = value
     return flat_params
+
+
+def is_integer(value: Any) -> bool:
+    """Check if the given value is an integer, excluding booleans."""
+    return isinstance(value, numbers.Integral) and not isinstance(value, bool)
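A short sketch of the `is_integer` semantics: `numbers.Integral` covers built-in ints and NumPy integer scalars, while `bool` (an `int` subclass) is explicitly rejected, which is what lets the callbacks above validate `log_every_n_epochs` strictly:

```python
import numpy as np
from snowflake.ml.experiment import utils

assert utils.is_integer(3)
assert utils.is_integer(np.int64(3))  # NumPy ints register as numbers.Integral
assert not utils.is_integer(True)     # bool is excluded despite being an int subclass
assert not utils.is_integer(3.0)
```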