snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
@@ -337,13 +337,54 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> r
     Returns:
         A requirements.Requirement object with supported ops only
     """
+
+    if req.name == "numpy":
+        import numpy as np
+
+        package_specifiers = get_numpy_specifiers(req, version.Version(np.__version__).major)
+    else:
+        package_specifiers = [spec for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS]
+
     new_req = copy.deepcopy(req)
-    new_req.specifier = specifiers.SpecifierSet(
-        specifiers=",".join([str(spec) for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS])
-    )
+    new_req.specifier = specifiers.SpecifierSet(specifiers=",".join([str(spec) for spec in package_specifiers]))
     return new_req


+def get_numpy_specifiers(
+    req: requirements.Requirement,
+    client_numpy_major_version: int,
+) -> list[specifiers.Specifier]:
+    """Get the package spec with supported ops only including ==, >=, <=, > and < based on the client numpy
+    major version.
+
+    Args:
+        req: A requirements.Requirement object showing the requirement.
+        client_numpy_major_version: The major version of numpy to be used.
+
+    Returns:
+        A list of specifiers with supported ops only
+    """
+    req_specifiers = []
+    for org_spec in req.specifier:
+        # check specifier that provides upper bound
+        if org_spec.operator in ["<", "<="]:
+            client_version = version.Version(str(client_numpy_major_version))
+            org_spec_version = version.Version(org_spec.version)
+            # check if the client's numpy major version is less than the specifier's upper bound
+            # if so, pin to max possible client major version
+            if client_version.major < org_spec_version.major:
+                modified_spec = specifiers.Specifier(f"<{client_version.major + 1}")
+                req_specifiers.append(modified_spec)
+            else:
+                # use the original specifier
+                req_specifiers.append(org_spec)
+        else:
+            # use the original specifier
+            req_specifiers.append(org_spec)
+
+    return req_specifiers
+
+
 def _relax_specifier_set(
     specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
 ) -> specifiers.SpecifierSet:
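A minimal standalone sketch of the specifier narrowing introduced above, runnable with only `packaging` installed (the requirement string and client major version here are hypothetical): when the client's numpy major version sits below a requirement's upper bound, the bound is tightened to the client's next major so the resolved environment matches the client.

```python
from packaging import requirements, specifiers, version

# Hypothetical requirement; on a real client the major version would come
# from version.Version(np.__version__).major, as in the diff above.
req = requirements.Requirement("numpy>=1.23,<3")
client_major = 1

narrowed = []
for spec in req.specifier:
    # Only upper bounds ("<", "<=") are candidates for narrowing.
    if spec.operator in ("<", "<=") and client_major < version.Version(spec.version).major:
        narrowed.append(specifiers.Specifier(f"<{client_major + 1}"))
    else:
        narrowed.append(spec)

print(specifiers.SpecifierSet(",".join(str(s) for s in narrowed)))  # <2,>=1.23
```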
@@ -3,7 +3,9 @@ from contextlib import contextmanager
 from typing import Any, Optional

 from absl import logging
+from packaging import version

+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import (
@@ -12,7 +14,7 @@ from snowflake.snowpark import (
 )

 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
-INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"


 class PlatformCapabilities:
@@ -67,7 +69,7 @@ class PlatformCapabilities:
         cls.clear_mock_features()

     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)
+        return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)

     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
@@ -126,3 +128,51 @@ class PlatformCapabilities:
         else:
             raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
         raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
+
+    def _get_version_feature(self, feature_name: str) -> version.Version:
+        """Get a version feature value, returning a large version number on failure or missing feature.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            version.Version: The parsed version, or a large version number (999.999.999) if parsing fails
+            or the feature is missing.
+        """
+        # Large version number to use as fallback
+        large_version = version.Version("999.999.999")
+
+        value = self.features.get(feature_name)
+        if value is None:
+            logging.debug(f"Feature {feature_name} not found, returning large version number")
+            return large_version
+
+        try:
+            # Convert to string if it's not already
+            version_str = str(value)
+            return version.Version(version_str)
+        except (version.InvalidVersion, ValueError, TypeError) as e:
+            logging.debug(
+                f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
+                f"Returning large version number"
+            )
+            return large_version
+
+    def _is_version_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the current package version is greater than or equal to the version feature.
+
+        Args:
+            feature_name: The name of the version feature to compare against.
+
+        Returns:
+            bool: True if current package version >= feature version, False otherwise.
+        """
+        current_version = version.Version(snowml_version.VERSION)
+        feature_version = self._get_version_feature(feature_name)
+
+        result = current_version >= feature_version
+        logging.debug(
+            f"Version comparison for feature {feature_name}: "
+            f"current={current_version}, feature={feature_version}, enabled={result}"
+        )
+        return result
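The effect of the renamed parameter, sketched with assumed values: the server-side feature now carries a minimum client version instead of a boolean, and the client enables the capability only when its own version meets it; missing or unparsable values fall back to 999.999.999, i.e. effectively disabled.

```python
from packaging import version

# Hypothetical feature map; real values come from platform capability queries.
features = {"ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION": "1.9.0"}

def is_enabled(feature_name: str, client_version: str) -> bool:
    raw = features.get(feature_name)
    try:
        required = version.Version(str(raw))
    except (version.InvalidVersion, TypeError):
        required = version.Version("999.999.999")  # fallback: effectively disabled
    return version.Version(client_version) >= required

print(is_enabled("ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION", "1.9.1"))  # True
print(is_enabled("ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION", "1.8.6"))  # False
```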
@@ -66,4 +66,4 @@ class LazyType(Generic[T]):
         return False


-LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.float_], npt.NDArray[np.str_], npt.NDArray[np.bool_]]
+LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.float64], npt.NDArray[np.str_], npt.NDArray[np.bool_]]
@@ -240,7 +240,7 @@ def get_schema_level_object_identifier(
     """

     for identifier in (db, schema, object_name):
-        if identifier is not None and SF_IDENTIFIER_RE.match(identifier) is None:
+        if identifier is not None and SF_IDENTIFIER_RE.fullmatch(identifier) is None:
             raise ValueError(f"Invalid identifier {identifier}")

     if others is None:
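Why `fullmatch` matters here: `re.match` anchors only at the start of the string, so an identifier with trailing junk would previously pass validation. A small illustration with a simplified pattern (not the library's actual `SF_IDENTIFIER_RE`):

```python
import re

# Illustrative identifier pattern; the real one lives in identifier.py.
SF_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")

print(bool(SF_IDENTIFIER_RE.match("MY_TABLE; DROP")))      # True  -- prefix matches
print(bool(SF_IDENTIFIER_RE.fullmatch("MY_TABLE; DROP")))  # False -- whole string must match
print(bool(SF_IDENTIFIER_RE.fullmatch("MY_TABLE")))        # True
```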
@@ -0,0 +1,71 @@
+from typing import Any, Optional
+
+from snowflake.ml._internal.utils import identifier
+from snowflake.snowpark import session as snowpark_session
+
+_SESSION_KEY = "_session"
+_SESSION_ACCOUNT_KEY = "session$account"
+_SESSION_ROLE_KEY = "session$role"
+_SESSION_DATABASE_KEY = "session$database"
+_SESSION_SCHEMA_KEY = "session$schema"
+
+
+def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
+    saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
+    current_resolved = identifier.resolve_identifier(current) if current is not None else current
+    return saved_resolved == current_resolved
+
+
+class SerializableSessionMixin:
+    """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        if hasattr(super(), "__getstate__"):
+            state: dict[str, Any] = super().__getstate__()  # type: ignore[misc]
+        else:
+            state = self.__dict__.copy()
+
+        # Save session metadata for validation during unpickling
+        session = state.pop(_SESSION_KEY, None)
+        if session is not None:
+            state[_SESSION_ACCOUNT_KEY] = session.get_current_account()
+            state[_SESSION_ROLE_KEY] = session.get_current_role()
+            state[_SESSION_DATABASE_KEY] = session.get_current_database()
+            state[_SESSION_SCHEMA_KEY] = session.get_current_schema()
+
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        saved_account = state.pop(_SESSION_ACCOUNT_KEY, None)
+        saved_role = state.pop(_SESSION_ROLE_KEY, None)
+        saved_database = state.pop(_SESSION_DATABASE_KEY, None)
+        saved_schema = state.pop(_SESSION_SCHEMA_KEY, None)
+
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)  # type: ignore[misc]
+        else:
+            self.__dict__.update(state)
+
+        if saved_account is not None:
+            active_sessions = snowpark_session._get_active_sessions()
+            if len(active_sessions) == 0:
+                raise RuntimeError("No active Snowpark session available. Please create a session.")
+
+            # Best effort match: Find the session with the most matching identifiers
+            setattr(
+                self,
+                _SESSION_KEY,
+                max(
+                    active_sessions,
+                    key=lambda s: sum(
+                        (
+                            _identifiers_match(saved_account, s.get_current_account()),
+                            _identifiers_match(saved_role, s.get_current_role()),
+                            _identifiers_match(saved_database, s.get_current_database()),
+                            _identifiers_match(saved_schema, s.get_current_schema()),
+                        )
+                    ),
+                ),
+            )
@@ -10,6 +10,10 @@ class LogColor(enum.Enum):
     YELLOW = "\x1b[33;20m"
     BLUE = "\x1b[34;20m"
     GREEN = "\x1b[32;20m"
+    ORANGE = "\x1b[38;5;214m"
+    BOLD_ORANGE = "\x1b[38;5;214;1m"
+    PURPLE = "\x1b[35;20m"
+    BOLD_PURPLE = "\x1b[35;1m"


 class CustomFormatter(logging.Formatter):
@@ -55,9 +59,7 @@ class CustomFormatter(logging.Formatter):

 def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    logger.setLevel(logging.INFO)
     handler = logging.StreamHandler(sys.stdout)
-    handler.setLevel(logging.INFO)
     handler.setFormatter(CustomFormatter(info_color))
     logger.addHandler(handler)
     return logger
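The removed `setLevel(logging.INFO)` calls had pinned these loggers to INFO regardless of application configuration; without them, the logger defers to its ancestors' effective level, so callers can tune verbosity globally. A quick demonstration of that inheritance using only the standard library:

```python
import logging
import sys

# With no explicit setLevel, the logger inherits its effective level from
# the root logger, so global configuration controls what gets emitted.
logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler(sys.stdout))

logging.getLogger().setLevel(logging.WARNING)
logger.info("hidden")    # filtered: effective level is WARNING
logger.warning("shown")  # emitted
```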
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union

 import numpy as np
 import numpy.typing as npt
@@ -10,6 +10,9 @@ import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset as pds

+if TYPE_CHECKING:
+    import ray
+
 from snowflake import snowpark
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils

@@ -70,6 +73,13 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
         return cls(session, sources)

+    @classmethod
+    def from_ray_dataset(
+        cls,
+        ray_ds: "ray.data.Dataset",
+    ) -> "ArrowIngestor":
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         return self._data_sources
@@ -6,6 +6,7 @@ from typing_extensions import deprecated

 from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
 from snowflake.snowpark import context as sf_context
@@ -21,11 +22,13 @@ if TYPE_CHECKING:
     from snowflake.ml import dataset

 _PROJECT = "DataConnector"
+_INGESTOR_KEY = "_ingestor"
+_INGESTOR_SOURCES_KEY = "ingestor$sources"

 DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")


-class DataConnector:
+class DataConnector(mixins.SerializableSessionMixin):
     """Snowflake data reader which provides application integration connectors"""

     DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
@@ -33,8 +36,11 @@ class DataConnector:
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
+        *,
+        session: Optional[snowpark.Session] = None,
         **kwargs: Any,
     ) -> None:
+        self._session = session
         self._ingestor = ingestor
         self._kwargs = kwargs

@@ -75,6 +81,17 @@ class DataConnector:
             )
         return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)

+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataConnectorType],
+        ray_ds: "ray.data.Dataset",
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+        return cls(ray_ingestor, **kwargs)
+
     @classmethod
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -90,7 +107,31 @@ class DataConnector:
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
-        return cls(ingestor, **kwargs)
+        return cls(ingestor, **kwargs, session=session)
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        if hasattr(super(), "__getstate__"):
+            state = super().__getstate__()
+        else:
+            state = self.__dict__.copy()
+
+        ingestor = state.pop(_INGESTOR_KEY)
+        state[_INGESTOR_SOURCES_KEY] = ingestor.data_sources
+
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        data_sources = state.pop(_INGESTOR_SOURCES_KEY)
+
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)
+        else:
+            self.__dict__.update(state)
+
+        assert self._session is not None
+        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)

     @property
     def data_sources(self) -> list[data_source.DataSource]:
@@ -7,6 +7,7 @@ from snowflake.ml.data import data_source

 if TYPE_CHECKING:
     import pandas as pd
+    import ray


 DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
@@ -19,6 +20,13 @@ class DataIngestor(Protocol):
     ) -> DataIngestorType:
         raise NotImplementedError

+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataIngestorType],
+        ray_ds: "ray.data.Dataset",
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         raise NotImplementedError
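`DataIngestor` is a `Protocol`, so `from_ray_dataset` is an extension hook that concrete ingestors may implement; the built-in `ArrowIngestor` currently raises `NotImplementedError`, per the hunk above. A hypothetical third-party ingestor satisfying the new classmethod might look like the following (the class and its eager-materialization strategy are invented for illustration; `to_arrow_refs` and `ray.get` are public Ray Data/Core APIs):

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import ray  # only needed for type checking, mirroring the diff

class MyRayIngestor:
    """Hypothetical ingestor that materializes a Ray Dataset eagerly."""

    def __init__(self, tables: list[Any]) -> None:
        self._tables = tables  # list of pyarrow.Table

    @classmethod
    def from_ray_dataset(cls, ray_ds: "ray.data.Dataset") -> "MyRayIngestor":
        import ray

        # to_arrow_refs() yields object refs to Arrow tables; ray.get resolves them.
        return cls(ray.get(ray_ds.to_arrow_refs()))
```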
@@ -95,6 +95,6 @@ def _preprocess_array(
     array_list = arr.tolist()
     # If this is an array of arrays, convert the dtype to match the underlying array.
     # Otherwise, if this is a numpy array of strings, convert the array to a list.
-    arr = np.array(array_list, dtype=arr.flat[0].dtype) if isinstance(arr.flat[0], np.ndarray) else array_list
+    arr = np.array(array_list, dtype=arr.item(0).dtype) if isinstance(arr.item(0), np.ndarray) else array_list

     return arr
@@ -14,6 +14,7 @@ from snowflake.ml._internal.exceptions import (
 from snowflake.ml._internal.utils import (
     formatting,
     identifier,
+    mixins,
     query_result_checker,
     snowpark_dataframe_utils,
 )
@@ -27,7 +28,7 @@ _METADATA_MAX_QUERY_LENGTH = 10000
 _DATASET_VERSION_NAME_COL = "version"


-class DatasetVersion:
+class DatasetVersion(mixins.SerializableSessionMixin):
     """Represents a version of a Snowflake Dataset"""

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -176,7 +177,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
+            self._reader = dataset_reader.DatasetReader.from_dataset(self)
         return self._reader

     @staticmethod
@@ -1,8 +1,10 @@
 from typing import Any, Optional
+from warnings import warn

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -11,7 +13,7 @@ _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"


-class DatasetReader(data_connector.DataConnector):
+class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixin):
     """Snowflake Dataset abstraction which provides application integration connectors"""

     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
@@ -19,14 +21,26 @@ class DatasetReader(data_connector.DataConnector):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        snowpark_session: snowpark.Session,
+        session: snowpark.Session,
+        snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
-        super().__init__(ingestor)
-
-        self._session: snowpark.Session = snowpark_session
-        self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
+        if snowpark_session is not None:
+            warn(
+                "Argument snowpark_session is deprecated and will be removed in a future release. Use session instead."
+            )
+            session = snowpark_session
+        super().__init__(ingestor, session=session)
+
+        self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None

+    @property
+    def _fs(self) -> snowfs.SnowFileSystem:
+        if self._fs_cached is None:
+            assert self._session is not None
+            self._fs_cached = ingestor_utils.get_dataset_filesystem(self._session)
+        return self._fs_cached
+
     @classmethod
     def from_dataframe(
         cls, df: snowpark.DataFrame, ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None, **kwargs: Any
@@ -42,6 +56,7 @@ class DatasetReader(data_connector.DataConnector):
         files: list[str] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo)
+            assert self._session is not None
             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
         files.sort()

@@ -95,6 +110,7 @@ class DatasetReader(data_connector.DataConnector):
         dfs: list[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
+            assert self._session is not None
             stage_reader = self._session.read.option("pattern", file_path_pattern)
             if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
                 stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
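The `snowpark_session` → `session` rename above follows the usual keyword-migration pattern: accept both names, warn on the old one, and forward its value. A generic sketch of that pattern (names here are illustrative; note the diff's `warn()` call uses the default `UserWarning`, while the common variant below passes `DeprecationWarning` explicitly):

```python
import warnings
from typing import Optional

def connect(session: Optional[object] = None, snowpark_session: Optional[object] = None) -> object:
    if snowpark_session is not None:
        warnings.warn(
            "Argument snowpark_session is deprecated; use session instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        session = snowpark_session  # forward the old keyword's value
    if session is None:
        raise ValueError("session is required")
    return session
```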
@@ -0,0 +1,98 @@
+from typing import Optional
+
+from snowflake.ml._internal.utils import query_result_checker, sql_identifier
+from snowflake.ml.model._client.sql import _base
+from snowflake.ml.utils import sql_client
+from snowflake.snowpark import row, session
+
+
+class ExperimentTrackingSQLClient(_base._BaseSQLClient):
+
+    RUN_NAME_COL_NAME = "name"
+    RUN_METADATA_COL_NAME = "metadata"
+
+    def __init__(
+        self,
+        session: session.Session,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        """Snowflake SQL Client to manage experiment tracking.
+
+        Args:
+            session: Active snowpark session.
+            database_name: Name of the Database where experiment tracking resources are provisioned.
+            schema_name: Name of the Schema where experiment tracking resources are provisioned.
+        """
+        super().__init__(session, database_name=database_name, schema_name=schema_name)
+
+    def create_experiment(
+        self,
+        experiment_name: sql_identifier.SqlIdentifier,
+        creation_mode: sql_client.CreationMode,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        if_not_exists_sql = "IF NOT EXISTS" if creation_mode.if_not_exists else ""
+        query_result_checker.SqlResultValidator(
+            self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
+            expected_rows=1, expected_cols=1
+        ).validate()
+
+    def add_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        live: bool = True,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def commit_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_run(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def modify_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        run_metadata: str,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_runs_in_experiment(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
+    ) -> list[row.Row]:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        like_clause = f"LIKE '{like}'" if like else ""
+        return query_result_checker.SqlResultValidator(
+            self._session, f"SHOW RUNS {like_clause} IN EXPERIMENT {experiment_fqn}"
+        ).validate()
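A hedged usage sketch of this new internal client, assuming an active Snowpark session `sess`; the method signatures follow the file above, but the `sql_client.CreationMode(if_not_exists=True)` constructor is assumed from the `creation_mode.if_not_exists` read in `create_experiment`:

```python
from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.experiment._client import experiment_tracking_sql_client
from snowflake.ml.utils import sql_client

# `sess` is an existing snowflake.snowpark.Session (assumed).
client = experiment_tracking_sql_client.ExperimentTrackingSQLClient(
    sess,
    database_name=sql_identifier.SqlIdentifier("ML_DB"),
    schema_name=sql_identifier.SqlIdentifier("PUBLIC"),
)
client.create_experiment(
    sql_identifier.SqlIdentifier("MY_EXPERIMENT"),
    creation_mode=sql_client.CreationMode(if_not_exists=True),  # constructor assumed
)
client.add_run(
    experiment_name=sql_identifier.SqlIdentifier("MY_EXPERIMENT"),
    run_name=sql_identifier.SqlIdentifier("RUN_1"),
)
client.commit_run(
    experiment_name=sql_identifier.SqlIdentifier("MY_EXPERIMENT"),
    run_name=sql_identifier.SqlIdentifier("RUN_1"),
)
```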
@@ -0,0 +1,4 @@
+from snowflake.ml.experiment._entities.experiment import Experiment
+from snowflake.ml.experiment._entities.run import Run
+
+__all__ = ["Experiment", "Run"]
@@ -0,0 +1,10 @@
+from snowflake.ml._internal.utils import sql_identifier
+
+
+class Experiment:
+    def __init__(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        self.name = experiment_name
@@ -0,0 +1,62 @@
+import json
+import types
+from typing import TYPE_CHECKING, Optional
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.experiment import _experiment_info as experiment_info
+from snowflake.ml.experiment._client import experiment_tracking_sql_client
+from snowflake.ml.experiment._entities import run_metadata
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment import experiment_tracking
+
+
+class Run:
+    def __init__(
+        self,
+        experiment_tracking: "experiment_tracking.ExperimentTracking",
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.experiment_name = experiment_name
+        self.name = run_name
+
+        self._patcher = experiment_info.ExperimentInfoPatcher(
+            experiment_info=self._get_experiment_info(),
+        )
+
+    def __enter__(self) -> "Run":
+        self._patcher.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        self._patcher.__exit__(exc_type, exc_value, traceback)
+        if self._experiment_tracking._run is self:
+            self._experiment_tracking.end_run()
+
+    def _get_metadata(
+        self,
+    ) -> run_metadata.RunMetadata:
+        runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
+            experiment_name=self.experiment_name, like=str(self.name)
+        )
+        if not runs:
+            raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
+        return run_metadata.RunMetadata.from_dict(
+            json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
+        )
+
+    def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
+        return experiment_info.ExperimentInfo(
+            fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
+                self._experiment_tracking._database_name, self._experiment_tracking._schema_name, self.experiment_name
+            ),
+            run_name=self.name.identifier(),
+        )