snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/env_utils.py
CHANGED
@@ -337,13 +337,54 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> requirements.Requirement:
     Returns:
         A requirements.Requirement object with supported ops only
     """
+    if req.name == "numpy":
+        import numpy as np
+
+        package_specifiers = get_numpy_specifiers(req, version.Version(np.__version__).major)
+    else:
+        package_specifiers = [spec for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS]
+
     new_req = copy.deepcopy(req)
-    new_req.specifier = specifiers.SpecifierSet(
-        specifiers=",".join([str(spec) for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS])
-    )
+    new_req.specifier = specifiers.SpecifierSet(specifiers=",".join([str(spec) for spec in package_specifiers]))
     return new_req


+def get_numpy_specifiers(
+    req: requirements.Requirement,
+    client_numpy_major_version: int,
+) -> list[specifiers.Specifier]:
+    """Get the package spec with supported ops only including ==, >=, <=, > and < based on the client numpy
+    major version.
+
+    Args:
+        req: A requirements.Requirement object showing the requirement.
+        client_numpy_major_version: The major version of numpy to be used.
+
+    Returns:
+        A list of specifiers with supported ops only
+    """
+    req_specifiers = []
+    for org_spec in req.specifier:
+        # check specifier that provides upper bound
+        if org_spec.operator in ["<", "<="]:
+            client_version = version.Version(str(client_numpy_major_version))
+            org_spec_version = version.Version(org_spec.version)
+            # check if the client's numpy major version is less than the specifier's upper bound
+            # if so, pin to max possible client major version
+            if client_version.major < org_spec_version.major:
+                modified_spec = specifiers.Specifier(f"<{client_version.major + 1}")
+                req_specifiers.append(modified_spec)
+            else:
+                # use the original specifier
+                req_specifiers.append(org_spec)
+        else:
+            # use the original specifier
+            req_specifiers.append(org_spec)
+
+    return req_specifiers
+
+
 def _relax_specifier_set(
     specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
 ) -> specifiers.SpecifierSet:
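The new `get_numpy_specifiers` clamps any numpy upper bound that would admit a major version newer than the one installed on the client, so dependency resolution cannot pick a numpy the client has never run against. A minimal sketch of that clamping rule using only `packaging`; `clamp_numpy_upper_bound` is an illustrative stand-in, not part of the snowflake-ml-python API:

```python
from packaging import specifiers, version


def clamp_numpy_upper_bound(
    spec_set: specifiers.SpecifierSet, client_major: int
) -> list[specifiers.Specifier]:
    result = []
    for spec in spec_set:
        # Only upper-bound specifiers can pull in a newer numpy major version.
        if spec.operator in ("<", "<=") and client_major < version.Version(spec.version).major:
            # Pin to the highest release within the client's installed major version.
            result.append(specifiers.Specifier(f"<{client_major + 1}"))
        else:
            result.append(spec)
    return result


# With numpy 1.x on the client, an upper bound of <3 tightens to <2:
print(sorted(str(s) for s in clamp_numpy_upper_bound(specifiers.SpecifierSet(">=1.23,<3"), 1)))
# ['<2', '>=1.23']
```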
snowflake/ml/_internal/platform_capabilities.py
CHANGED
@@ -3,7 +3,9 @@ from contextlib import contextmanager
 from typing import Any, Optional

 from absl import logging
+from packaging import version

+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import (
@@ -12,7 +14,7 @@ from snowflake.snowpark import (
 )

 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
-INLINE_DEPLOYMENT_SPEC_PARAMETER = "
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"


 class PlatformCapabilities:
@@ -67,7 +69,7 @@ class PlatformCapabilities:
         cls.clear_mock_features()

     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self.
+        return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)

     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
@@ -126,3 +128,51 @@ class PlatformCapabilities:
         else:
             raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
         raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
+
+    def _get_version_feature(self, feature_name: str) -> version.Version:
+        """Get a version feature value, returning a large version number on failure or missing feature.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            version.Version: The parsed version, or a large version number (999.999.999) if parsing fails
+                or the feature is missing.
+        """
+        # Large version number to use as fallback
+        large_version = version.Version("999.999.999")
+
+        value = self.features.get(feature_name)
+        if value is None:
+            logging.debug(f"Feature {feature_name} not found, returning large version number")
+            return large_version
+
+        try:
+            # Convert to string if it's not already
+            version_str = str(value)
+            return version.Version(version_str)
+        except (version.InvalidVersion, ValueError, TypeError) as e:
+            logging.debug(
+                f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
+                f"Returning large version number"
+            )
+            return large_version
+
+    def _is_version_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the current package version is greater than or equal to the version feature.
+
+        Args:
+            feature_name: The name of the version feature to compare against.
+
+        Returns:
+            bool: True if current package version >= feature version, False otherwise.
+        """
+        current_version = version.Version(snowml_version.VERSION)
+        feature_version = self._get_version_feature(feature_name)
+
+        result = current_version >= feature_version
+        logging.debug(
+            f"Version comparison for feature {feature_name}: "
+            f"current={current_version}, feature={feature_version}, enabled={result}"
+        )
+        return result
snowflake/ml/_internal/type_utils.py
CHANGED
@@ -66,4 +66,4 @@ class LazyType(Generic[T]):
         return False


-LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.
+LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.float64], npt.NDArray[np.str_], npt.NDArray[np.bool_]]
snowflake/ml/_internal/utils/identifier.py
CHANGED
@@ -240,7 +240,7 @@ def get_schema_level_object_identifier(
     """

     for identifier in (db, schema, object_name):
-        if identifier is not None and SF_IDENTIFIER_RE.
+        if identifier is not None and SF_IDENTIFIER_RE.fullmatch(identifier) is None:
             raise ValueError(f"Invalid identifier {identifier}")

     if others is None:
snowflake/ml/_internal/utils/mixins.py
ADDED
@@ -0,0 +1,71 @@
+from typing import Any, Optional
+
+from snowflake.ml._internal.utils import identifier
+from snowflake.snowpark import session as snowpark_session
+
+_SESSION_KEY = "_session"
+_SESSION_ACCOUNT_KEY = "session$account"
+_SESSION_ROLE_KEY = "session$role"
+_SESSION_DATABASE_KEY = "session$database"
+_SESSION_SCHEMA_KEY = "session$schema"
+
+
+def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
+    saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
+    current_resolved = identifier.resolve_identifier(current) if current is not None else current
+    return saved_resolved == current_resolved
+
+
+class SerializableSessionMixin:
+    """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        if hasattr(super(), "__getstate__"):
+            state: dict[str, Any] = super().__getstate__()  # type: ignore[misc]
+        else:
+            state = self.__dict__.copy()
+
+        # Save session metadata for validation during unpickling
+        session = state.pop(_SESSION_KEY, None)
+        if session is not None:
+            state[_SESSION_ACCOUNT_KEY] = session.get_current_account()
+            state[_SESSION_ROLE_KEY] = session.get_current_role()
+            state[_SESSION_DATABASE_KEY] = session.get_current_database()
+            state[_SESSION_SCHEMA_KEY] = session.get_current_schema()
+
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        saved_account = state.pop(_SESSION_ACCOUNT_KEY, None)
+        saved_role = state.pop(_SESSION_ROLE_KEY, None)
+        saved_database = state.pop(_SESSION_DATABASE_KEY, None)
+        saved_schema = state.pop(_SESSION_SCHEMA_KEY, None)
+
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)  # type: ignore[misc]
+        else:
+            self.__dict__.update(state)
+
+        if saved_account is not None:
+            active_sessions = snowpark_session._get_active_sessions()
+            if len(active_sessions) == 0:
+                raise RuntimeError("No active Snowpark session available. Please create a session.")
+
+            # Best effort match: Find the session with the most matching identifiers
+            setattr(
+                self,
+                _SESSION_KEY,
+                max(
+                    active_sessions,
+                    key=lambda s: sum(
+                        (
+                            _identifiers_match(saved_account, s.get_current_account()),
+                            _identifiers_match(saved_role, s.get_current_role()),
+                            _identifiers_match(saved_database, s.get_current_database()),
+                            _identifiers_match(saved_schema, s.get_current_schema()),
+                        )
+                    ),
+                ),
+            )
snowflake/ml/_internal/utils/service_logger.py
CHANGED
@@ -10,6 +10,10 @@ class LogColor(enum.Enum):
     YELLOW = "\x1b[33;20m"
     BLUE = "\x1b[34;20m"
     GREEN = "\x1b[32;20m"
+    ORANGE = "\x1b[38;5;214m"
+    BOLD_ORANGE = "\x1b[38;5;214;1m"
+    PURPLE = "\x1b[35;20m"
+    BOLD_PURPLE = "\x1b[35;1m"


 class CustomFormatter(logging.Formatter):
@@ -55,9 +59,7 @@ class CustomFormatter(logging.Formatter):

 def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    logger.setLevel(logging.INFO)
     handler = logging.StreamHandler(sys.stdout)
-    handler.setLevel(logging.INFO)
     handler.setFormatter(CustomFormatter(info_color))
     logger.addHandler(handler)
     return logger
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union

 import numpy as np
 import numpy.typing as npt
@@ -10,6 +10,9 @@ import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset as pds

+if TYPE_CHECKING:
+    import ray
+
 from snowflake import snowpark
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils

@@ -70,6 +73,13 @@ class ArrowIngestor(data_ingestor.DataIngestor):
     def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
         return cls(session, sources)

+    @classmethod
+    def from_ray_dataset(
+        cls,
+        ray_ds: "ray.data.Dataset",
+    ) -> "ArrowIngestor":
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         return self._data_sources
snowflake/ml/data/data_connector.py
CHANGED
@@ -6,6 +6,7 @@ from typing_extensions import deprecated

 from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
 from snowflake.snowpark import context as sf_context
@@ -21,11 +22,13 @@ if TYPE_CHECKING:
     from snowflake.ml import dataset

 _PROJECT = "DataConnector"
+_INGESTOR_KEY = "_ingestor"
+_INGESTOR_SOURCES_KEY = "ingestor$sources"

 DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")


-class DataConnector:
+class DataConnector(mixins.SerializableSessionMixin):
     """Snowflake data reader which provides application integration connectors"""

     DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
@@ -33,8 +36,11 @@ class DataConnector:
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
+        *,
+        session: Optional[snowpark.Session] = None,
         **kwargs: Any,
     ) -> None:
+        self._session = session
         self._ingestor = ingestor
         self._kwargs = kwargs

@@ -75,6 +81,17 @@ class DataConnector:
         )
         return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)

+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataConnectorType],
+        ray_ds: "ray.data.Dataset",
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+        return cls(ray_ingestor, **kwargs)
+
     @classmethod
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -90,7 +107,31 @@ class DataConnector:
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
         ingestor = ingestor_class.from_sources(session, sources)
-        return cls(ingestor, **kwargs)
+        return cls(ingestor, **kwargs, session=session)
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        if hasattr(super(), "__getstate__"):
+            state = super().__getstate__()
+        else:
+            state = self.__dict__.copy()
+
+        ingestor = state.pop(_INGESTOR_KEY)
+        state[_INGESTOR_SOURCES_KEY] = ingestor.data_sources
+
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        data_sources = state.pop(_INGESTOR_SOURCES_KEY)
+
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)
+        else:
+            self.__dict__.update(state)
+
+        assert self._session is not None
+        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)

     @property
     def data_sources(self) -> list[data_source.DataSource]:
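`DataConnector` performs the complementary surgery on the ingestor: only `data_sources` travel in the pickle, and a `DEFAULT_INGESTOR_CLASS` instance is rebuilt from them on arrival. A stand-in sketch of that pattern; the real class delegates session handling to the mixin above, so `_session` is a plain string here to keep the example runnable:

```python
import pickle
from typing import Any


class Ingestor:
    def __init__(self, sources: list) -> None:
        self.data_sources = sources

    @classmethod
    def from_sources(cls, session: str, sources: list) -> "Ingestor":
        return cls(sources)


class Connector:
    DEFAULT_INGESTOR_CLASS = Ingestor

    def __init__(self, ingestor: Ingestor, session: str) -> None:
        self._ingestor = ingestor
        self._session = session

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        ingestor = state.pop("_ingestor")  # the live ingestor never travels
        state["ingestor$sources"] = ingestor.data_sources  # only its sources do
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        data_sources = state.pop("ingestor$sources")
        self.__dict__.update(state)
        # Rebuild a default ingestor from the restored session and sources.
        self._ingestor = self.DEFAULT_INGESTOR_CLASS.from_sources(self._session, data_sources)


dc = Connector(Ingestor(["snow://dataset/v1"]), session="sess")
restored = pickle.loads(pickle.dumps(dc))
print(restored._ingestor.data_sources)  # ['snow://dataset/v1']
```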
snowflake/ml/data/data_ingestor.py
CHANGED
@@ -7,6 +7,7 @@ from snowflake.ml.data import data_source

 if TYPE_CHECKING:
     import pandas as pd
+    import ray


 DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
@@ -19,6 +20,13 @@ class DataIngestor(Protocol):
     ) -> DataIngestorType:
         raise NotImplementedError

+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataIngestorType],
+        ray_ds: "ray.data.Dataset",
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         raise NotImplementedError
snowflake/ml/data/torch_utils.py
CHANGED
@@ -95,6 +95,6 @@ def _preprocess_array(
     array_list = arr.tolist()
     # If this is an array of arrays, convert the dtype to match the underlying array.
     # Otherwise, if this is a numpy array of strings, convert the array to a list.
-    arr = np.array(array_list, dtype=arr.
+    arr = np.array(array_list, dtype=arr.item(0).dtype) if isinstance(arr.item(0), np.ndarray) else array_list

     return arr
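The repaired branch rebuilds an array-of-arrays with the inner array's dtype and degrades a string array to a plain list. The same logic in plain numpy:

```python
import numpy as np

# Array of arrays: rebuilt as a 2-D float64 array.
nested = np.empty(2, dtype=object)
nested[0] = np.array([1.0, 2.0])
nested[1] = np.array([3.0, 4.0])
array_list = nested.tolist()
first = nested.item(0)
out = np.array(array_list, dtype=first.dtype) if isinstance(first, np.ndarray) else array_list
print(type(out), out.dtype)  # <class 'numpy.ndarray'> float64

# Array of strings: falls through to the plain Python list.
strings = np.array(["a", "b"])
array_list = strings.tolist()
first = strings.item(0)
out = np.array(array_list, dtype=first.dtype) if isinstance(first, np.ndarray) else array_list
print(type(out), out)  # <class 'list'> ['a', 'b']
```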
snowflake/ml/dataset/dataset.py
CHANGED
@@ -14,6 +14,7 @@ from snowflake.ml._internal.exceptions import (
 from snowflake.ml._internal.utils import (
     formatting,
     identifier,
+    mixins,
     query_result_checker,
     snowpark_dataframe_utils,
 )
@@ -27,7 +28,7 @@ _METADATA_MAX_QUERY_LENGTH = 10000
 _DATASET_VERSION_NAME_COL = "version"


-class DatasetVersion:
+class DatasetVersion(mixins.SerializableSessionMixin):
     """Represents a version of a Snowflake Dataset"""

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -176,7 +177,7 @@ class Dataset(lineage_node.LineageNode):
                 original_exception=RuntimeError("No Dataset version selected."),
             )
         if self._reader is None:
-            self._reader = dataset_reader.DatasetReader.from_dataset(self
+            self._reader = dataset_reader.DatasetReader.from_dataset(self)
         return self._reader

     @staticmethod
snowflake/ml/dataset/dataset_reader.py
CHANGED
@@ -1,8 +1,10 @@
 from typing import Any, Optional
+from warnings import warn

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -11,7 +13,7 @@ _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"


-class DatasetReader(data_connector.DataConnector):
+class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixin):
     """Snowflake Dataset abstraction which provides application integration connectors"""

     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
@@ -19,14 +21,26 @@ class DatasetReader(data_connector.DataConnector):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-
+        session: snowpark.Session,
+        snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
-
-
-
-
+        if snowpark_session is not None:
+            warn(
+                "Argument snowpark_session is deprecated and will be removed in a future release. Use session instead."
+            )
+            session = snowpark_session
+        super().__init__(ingestor, session=session)
+
+        self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None

+    @property
+    def _fs(self) -> snowfs.SnowFileSystem:
+        if self._fs_cached is None:
+            assert self._session is not None
+            self._fs_cached = ingestor_utils.get_dataset_filesystem(self._session)
+        return self._fs_cached
+
     @classmethod
     def from_dataframe(
         cls, df: snowpark.DataFrame, ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None, **kwargs: Any
@@ -42,6 +56,7 @@ class DatasetReader(data_connector.DataConnector):
         files: list[str] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo)
+            assert self._session is not None
             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
         files.sort()

@@ -95,6 +110,7 @@ class DatasetReader(data_connector.DataConnector):
         dfs: list[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
+            assert self._session is not None
             stage_reader = self._session.read.option("pattern", file_path_pattern)
             if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
                 stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
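The filesystem handle moves from eager construction in `__init__` to a lazily cached property, so an unpickled reader needs no session until files are actually touched. The pattern in isolation, with stand-in types:

```python
from typing import Optional


class Reader:
    def __init__(self) -> None:
        self._fs_cached: Optional[str] = None  # stands in for snowfs.SnowFileSystem

    @property
    def _fs(self) -> str:
        if self._fs_cached is None:
            print("building filesystem handle")  # runs once, on first access
            self._fs_cached = "fs-handle"
        return self._fs_cached


r = Reader()  # no handle built yet; safe right after unpickling
r._fs  # builds
r._fs  # served from cache
```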
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
ADDED
@@ -0,0 +1,98 @@
+from typing import Optional
+
+from snowflake.ml._internal.utils import query_result_checker, sql_identifier
+from snowflake.ml.model._client.sql import _base
+from snowflake.ml.utils import sql_client
+from snowflake.snowpark import row, session
+
+
+class ExperimentTrackingSQLClient(_base._BaseSQLClient):
+
+    RUN_NAME_COL_NAME = "name"
+    RUN_METADATA_COL_NAME = "metadata"
+
+    def __init__(
+        self,
+        session: session.Session,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        """Snowflake SQL Client to manage experiment tracking.
+
+        Args:
+            session: Active snowpark session.
+            database_name: Name of the Database where experiment tracking resources are provisioned.
+            schema_name: Name of the Schema where experiment tracking resources are provisioned.
+        """
+        super().__init__(session, database_name=database_name, schema_name=schema_name)
+
+    def create_experiment(
+        self,
+        experiment_name: sql_identifier.SqlIdentifier,
+        creation_mode: sql_client.CreationMode,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        if_not_exists_sql = "IF NOT EXISTS" if creation_mode.if_not_exists else ""
+        query_result_checker.SqlResultValidator(
+            self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
+            expected_rows=1, expected_cols=1
+        ).validate()
+
+    def add_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        live: bool = True,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def commit_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_run(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def modify_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        run_metadata: str,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_runs_in_experiment(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
+    ) -> list[row.Row]:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        like_clause = f"LIKE '{like}'" if like else ""
+        return query_result_checker.SqlResultValidator(
+            self._session, f"SHOW RUNS {like_clause} IN EXPERIMENT {experiment_fqn}"
+        ).validate()
snowflake/ml/experiment/_entities/run.py
ADDED
@@ -0,0 +1,62 @@
+import json
+import types
+from typing import TYPE_CHECKING, Optional
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.experiment import _experiment_info as experiment_info
+from snowflake.ml.experiment._client import experiment_tracking_sql_client
+from snowflake.ml.experiment._entities import run_metadata
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment import experiment_tracking
+
+
+class Run:
+    def __init__(
+        self,
+        experiment_tracking: "experiment_tracking.ExperimentTracking",
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.experiment_name = experiment_name
+        self.name = run_name
+
+        self._patcher = experiment_info.ExperimentInfoPatcher(
+            experiment_info=self._get_experiment_info(),
+        )
+
+    def __enter__(self) -> "Run":
+        self._patcher.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        self._patcher.__exit__(exc_type, exc_value, traceback)
+        if self._experiment_tracking._run is self:
+            self._experiment_tracking.end_run()
+
+    def _get_metadata(
+        self,
+    ) -> run_metadata.RunMetadata:
+        runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
+            experiment_name=self.experiment_name, like=str(self.name)
+        )
+        if not runs:
+            raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
+        return run_metadata.RunMetadata.from_dict(
+            json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
+        )
+
+    def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
+        return experiment_info.ExperimentInfo(
+            fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
+                self._experiment_tracking._database_name, self._experiment_tracking._schema_name, self.experiment_name
+            ),
+            run_name=self.name.identifier(),
+        )