snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/experiment/_entities/run_metadata.py
@@ -0,0 +1,68 @@
+import dataclasses
+import enum
+import typing
+
+
+class RunStatus(str, enum.Enum):
+    UNKNOWN = "UNKNOWN"
+    RUNNING = "RUNNING"
+    FINISHED = "FINISHED"
+
+
+@dataclasses.dataclass
+class Metric:
+    name: str
+    value: float
+    step: int
+
+
+@dataclasses.dataclass
+class Param:
+    name: str
+    value: str
+
+
+@dataclasses.dataclass
+class RunMetadata:
+    status: RunStatus
+    metrics: list[Metric]
+    parameters: list[Param]
+
+    @classmethod
+    def from_dict(
+        cls,
+        metadata: dict,  # type: ignore[type-arg]
+    ) -> "RunMetadata":
+        return RunMetadata(
+            status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
+            metrics=[Metric(**m) for m in metadata.get("metrics", [])],
+            parameters=[Param(**p) for p in metadata.get("parameters", [])],
+        )
+
+    def to_dict(self) -> dict:  # type: ignore[type-arg]
+        return dataclasses.asdict(self)
+
+    def set_metric(
+        self,
+        key: str,
+        value: float,
+        step: int,
+    ) -> None:
+        for metric in self.metrics:
+            if metric.name == key and metric.step == step:
+                metric.value = value
+                break
+        else:
+            self.metrics.append(Metric(name=key, value=value, step=step))
+
+    def set_param(
+        self,
+        key: str,
+        value: typing.Any,
+    ) -> None:
+        for parameter in self.parameters:
+            if parameter.name == key:
+                parameter.value = str(value)
+                break
+        else:
+            self.parameters.append(Param(name=key, value=str(value)))
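These dataclasses round-trip through plain dicts, which is how run metadata is serialized for storage. A minimal sketch of that behavior, using only the API defined in this file (the input dict shape is inferred from from_dict above):

    # Hypothetical round-trip; dict shape inferred from from_dict.
    metadata = RunMetadata.from_dict(
        {
            "status": "RUNNING",
            "metrics": [{"name": "loss", "value": 0.42, "step": 0}],
            "parameters": [{"name": "lr", "value": "0.001"}],
        }
    )
    metadata.set_metric("loss", 0.40, step=0)  # overwrites the existing (name, step) pair
    metadata.set_metric("loss", 0.35, step=1)  # appends a new step
    metadata.set_param("optimizer", "adam")    # non-string values are stringified
    payload = metadata.to_dict()               # JSON-serializable via dataclasses.asdict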
snowflake/ml/experiment/_experiment_info.py
@@ -0,0 +1,63 @@
+import dataclasses
+import functools
+import types
+from typing import Callable, Optional
+
+from snowflake.ml import model
+from snowflake.ml.registry._manager import model_manager
+
+
+@dataclasses.dataclass(frozen=True)
+class ExperimentInfo:
+    """Serializable information identifying a Experiment"""
+
+    fully_qualified_name: str
+    run_name: str
+
+
+class ExperimentInfoPatcher:
+    """Context manager that patches ModelManager.log_model to include experiment information.
+
+    This class maintains a stack of active experiment contexts and ensures that
+    log_model calls are automatically tagged with the appropriate experiment info.
+    """
+
+    # Store original method at class definition time to avoid recursive patching
+    _original_log_model: Callable[..., model.ModelVersion] = model_manager.ModelManager.log_model
+
+    # Stack of active experiment_info contexts for nested experiment support
+    _experiment_info_stack: list[ExperimentInfo] = []
+
+    def __init__(self, experiment_info: ExperimentInfo) -> None:
+        self._experiment_info = experiment_info
+
+    def __enter__(self) -> "ExperimentInfoPatcher":
+        # Only patch ModelManager.log_model if we're the first patcher to avoid nested patching
+        if not ExperimentInfoPatcher._experiment_info_stack:
+
+            @functools.wraps(ExperimentInfoPatcher._original_log_model)
+            def patched(*args, **kwargs) -> model.ModelVersion:  # type: ignore[no-untyped-def]
+                # Use the most recent (top of stack) experiment_info for nested contexts
+                current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
+                return ExperimentInfoPatcher._original_log_model(
+                    *args, **kwargs, experiment_info=current_experiment_info
+                )
+
+            model_manager.ModelManager.log_model = patched  # type: ignore[method-assign]
+
+        ExperimentInfoPatcher._experiment_info_stack.append(self._experiment_info)
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        ExperimentInfoPatcher._experiment_info_stack.pop()
+
+        # Restore original method when no patches are active to clean up properly
+        if not ExperimentInfoPatcher._experiment_info_stack:
+            model_manager.ModelManager.log_model = (  # type: ignore[method-assign]
+                ExperimentInfoPatcher._original_log_model
+            )
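ExperimentInfoPatcher is meant to be used as a re-entrant context manager. A short sketch, using only the classes above; the ExperimentInfo field values are illustrative:

    info = ExperimentInfo(fully_qualified_name="ML_DB.PUBLIC.MY_EXPERIMENT", run_name="RUN_1")
    with ExperimentInfoPatcher(experiment_info=info):
        # While this block is active, every ModelManager.log_model call is
        # transparently passed experiment_info=info; nested patchers push onto
        # the shared class-level stack and the innermost one wins.
        ...
    # Exiting the outermost patcher restores the original log_model.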
snowflake/ml/experiment/experiment_tracking.py
@@ -0,0 +1,319 @@
+import functools
+import json
+import sys
+from typing import Any, Optional, Union
+from urllib.parse import quote
+
+import snowflake.snowpark._internal.utils as snowpark_utils
+from snowflake.ml import model, registry
+from snowflake.ml._internal.human_readable_id import hrid_generator
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.experiment import (
+    _entities as entities,
+    _experiment_info as experiment_info,
+)
+from snowflake.ml.experiment._client import experiment_tracking_sql_client as sql_client
+from snowflake.ml.model import type_hints
+from snowflake.ml.utils import sql_client as sql_client_utils
+from snowflake.snowpark import session
+
+DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
+
+
+class ExperimentTracking:
+    """
+    Class to manage experiments in Snowflake.
+    """
+
+    @snowpark_utils.private_preview(version="1.9.1")
+    def __init__(
+        self,
+        session: session.Session,
+        *,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
+    ) -> None:
+        """
+        Initializes experiment tracking within a pre-created schema.
+
+        Args:
+            session: The Snowpark Session to connect with Snowflake.
+            database_name: The name of the database. If None, the current database of the session
+                will be used. Defaults to None.
+            schema_name: The name of the schema. If None, the current schema of the session
+                will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
+
+        Raises:
+            ValueError: If no database is provided and no active database exists in the session.
+        """
+        if database_name:
+            self._database_name = sql_identifier.SqlIdentifier(database_name)
+        elif session_db := session.get_current_database():
+            self._database_name = sql_identifier.SqlIdentifier(session_db)
+        else:
+            raise ValueError("You need to provide a database to use experiment tracking.")
+
+        if schema_name:
+            self._schema_name = sql_identifier.SqlIdentifier(schema_name)
+        elif session_schema := session.get_current_schema():
+            self._schema_name = sql_identifier.SqlIdentifier(session_schema)
+        else:
+            self._schema_name = sql_identifier.SqlIdentifier("PUBLIC")
+
+        self._sql_client = sql_client.ExperimentTrackingSQLClient(
+            session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+        self._registry = registry.Registry(
+            session=session,
+            database_name=self._database_name,
+            schema_name=self._schema_name,
+        )
+
+        # The experiment in context
+        self._experiment: Optional[entities.Experiment] = None
+        # The run in context
+        self._run: Optional[entities.Run] = None
+
+    def set_experiment(
+        self,
+        experiment_name: str,
+    ) -> entities.Experiment:
+        """
+        Set the experiment in context. Creates a new experiment if it doesn't exist.
+
+        Args:
+            experiment_name: The name of the experiment.
+
+        Returns:
+            Experiment: The experiment that was set.
+        """
+        experiment_name = sql_identifier.SqlIdentifier(experiment_name)
+        if self._experiment and self._experiment.name == experiment_name:
+            return self._experiment
+        self._sql_client.create_experiment(
+            experiment_name=experiment_name,
+            creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
+        )
+        self._experiment = entities.Experiment(experiment_name=experiment_name)
+        self._run = None
+        return self._experiment
+
+    def delete_experiment(
+        self,
+        experiment_name: str,
+    ) -> None:
+        """
+        Delete an experiment.
+
+        Args:
+            experiment_name: The name of the experiment.
+        """
+        self._sql_client.drop_experiment(experiment_name=sql_identifier.SqlIdentifier(experiment_name))
+        if self._experiment and self._experiment.name == experiment_name:
+            self._experiment = None
+            self._run = None
+
+    @functools.wraps(registry.Registry.log_model)
+    def log_model(
+        self,
+        model: Union[type_hints.SupportedModelType, model.ModelVersion],
+        *,
+        model_name: str,
+        **kwargs: Any,
+    ) -> model.ModelVersion:
+        run = self._get_or_start_run()
+        with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
+            return self._registry.log_model(model, model_name=model_name, **kwargs)
+
+    def start_run(
+        self,
+        run_name: Optional[str] = None,
+    ) -> entities.Run:
+        """
+        Start a new run.
+
+        Args:
+            run_name: The name of the run. If None, a default name will be generated.
+
+        Returns:
+            Run: The run that was started.
+
+        Raises:
+            RuntimeError: If a run is already active.
+        """
+        if self._run:
+            raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
+        experiment = self._get_or_set_experiment()
+        run_name = (
+            sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
+        )
+        self._sql_client.add_run(
+            experiment_name=experiment.name,
+            run_name=run_name,
+        )
+        self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
+        return self._run
+
+    def end_run(self, run_name: Optional[str] = None) -> None:
+        """
+        End the current run if no run name is provided. Otherwise, the specified run is ended.
+
+        Args:
+            run_name: The name of the run to be ended. If None, the current run is ended.
+
+        Raises:
+            RuntimeError: If no run is active.
+        """
+        if not self._experiment:
+            raise RuntimeError("No experiment set. Please set an experiment before ending a run.")
+        experiment_name = self._experiment.name
+
+        if run_name:
+            run_name = sql_identifier.SqlIdentifier(run_name)
+        elif self._run:
+            run_name = self._run.name
+        else:
+            raise RuntimeError("No run is active. Please start a run before ending it.")
+
+        self._sql_client.commit_run(
+            experiment_name=experiment_name,
+            run_name=run_name,
+        )
+        if self._run and run_name == self._run.name:
+            self._run = None
+        self._print_urls(experiment_name=experiment_name, run_name=run_name)
+
+    def delete_run(
+        self,
+        run_name: str,
+    ) -> None:
+        """
+        Delete a run.
+
+        Args:
+            run_name: The name of the run to be deleted.
+
+        Raises:
+            RuntimeError: If no experiment is set.
+        """
+        if not self._experiment:
+            raise RuntimeError("No experiment set. Please set an experiment before deleting a run.")
+        self._sql_client.drop_run(
+            experiment_name=self._experiment.name,
+            run_name=sql_identifier.SqlIdentifier(run_name),
+        )
+        if self._run and self._run.name == run_name:
+            self._run = None
+
+    def log_metric(
+        self,
+        key: str,
+        value: float,
+        step: int = 0,
+    ) -> None:
+        """
+        Log a metric under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            key: The name of the metric.
+            value: The value of the metric.
+            step: The step of the metric. Defaults to 0.
+        """
+        self.log_metrics(metrics={key: value}, step=step)
+
+    def log_metrics(
+        self,
+        metrics: dict[str, float],
+        step: int = 0,
+    ) -> None:
+        """
+        Log metrics under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            metrics: Dictionary containing metric keys and float values.
+            step: The step of the metrics. Defaults to 0.
+        """
+        run = self._get_or_start_run()
+        metadata = run._get_metadata()
+        for key, value in metrics.items():
+            metadata.set_metric(key, value, step)
+        self._sql_client.modify_run(
+            experiment_name=run.experiment_name,
+            run_name=run.name,
+            run_metadata=json.dumps(metadata.to_dict()),
+        )
+
+    def log_param(
+        self,
+        key: str,
+        value: Any,
+    ) -> None:
+        """
+        Log a parameter under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            key: The name of the parameter.
+            value: The value of the parameter. Values can be of any type, but will be converted to string.
+        """
+        self.log_params({key: value})
+
+    def log_params(
+        self,
+        params: dict[str, Any],
+    ) -> None:
+        """
+        Log parameters under the current run. If no run is active, this method will create a new run.
+
+        Args:
+            params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
+                to string.
+        """
+        run = self._get_or_start_run()
+        metadata = run._get_metadata()
+        for key, value in params.items():
+            metadata.set_param(key, value)
+        self._sql_client.modify_run(
+            experiment_name=run.experiment_name,
+            run_name=run.name,
+            run_metadata=json.dumps(metadata.to_dict()),
+        )
+
+    def _get_or_set_experiment(self) -> entities.Experiment:
+        if self._experiment:
+            return self._experiment
+        return self.set_experiment(experiment_name=DEFAULT_EXPERIMENT_NAME)
+
+    def _get_or_start_run(self) -> entities.Run:
+        if self._run:
+            return self._run
+        return self.start_run()
+
+    def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
+        generator = hrid_generator.HRID16()
+        existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
+        existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
+        for _ in range(1000):
+            run_name = generator.generate()[1]
+            if run_name not in existing_run_names:
+                return sql_identifier.SqlIdentifier(run_name)
+        raise RuntimeError("Random run name generation failed.")
+
+    def _print_urls(
+        self,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        scheme: str = "https",
+        host: str = "app.snowflake.com",
+    ) -> None:
+
+        experiment_url = (
+            f"{scheme}://{host}/_deeplink/#/experiments"
+            f"/databases/{quote(str(self._database_name))}"
+            f"/schemas/{quote(str(self._schema_name))}"
+            f"/experiments/{quote(str(experiment_name))}"
+        )
+        run_url = experiment_url + f"/runs/{quote(str(run_name))}"
+        sys.stdout.write(f"🏃 View run {run_name} at: {run_url}\n")
+        sys.stdout.write(f"🧪 View experiment at: {experiment_url}\n")
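Taken together, this file gives the package an MLflow-style tracking surface. A hedged usage sketch (the import path follows this file's location; the connection parameters are illustrative, and the API is marked private preview by the decorator above):

    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
    from snowflake.snowpark import Session

    connection_parameters = {"account": "...", "user": "..."}  # illustrative
    session = Session.builder.configs(connection_parameters).create()

    exp = ExperimentTracking(session, database_name="ML_DB", schema_name="PUBLIC")
    exp.set_experiment("MY_EXPERIMENT")
    exp.start_run("baseline")
    exp.log_params({"lr": 0.001, "epochs": 10})
    for step in range(10):
        exp.log_metric("loss", 1.0 / (step + 1), step=step)
    exp.end_run()  # commits the run and prints Snowsight deep links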
snowflake/ml/jobs/_utils/constants.py
@@ -15,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.5.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"

 # Percent of container memory to allocate for /dev/shm volume
snowflake/ml/jobs/_utils/interop_utils.py
@@ -75,16 +75,75 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult

     Returns:
         A dictionary containing the execution result if available, None otherwise.
+
+    Raises:
+        RuntimeError: If both pickle and JSON result retrieval fail.
     """
     try:
         # TODO: Check if file exists
         with session.file.get_stream(result_path) as result_stream:
             return ExecutionResult.from_dict(pickle.load(result_stream))
-    except (
+    except (
+        sp_exceptions.SnowparkSQLException,
+        pickle.UnpicklingError,
+        TypeError,
+        ImportError,
+        AttributeError,
+        MemoryError,
+    ) as pickle_error:
         # Fall back to JSON result if loading pickled result fails for any reason
-
-
-
+        try:
+            result_json_path = os.path.splitext(result_path)[0] + ".json"
+            with session.file.get_stream(result_json_path) as result_stream:
+                return ExecutionResult.from_dict(json.load(result_stream))
+        except Exception as json_error:
+            # Both pickle and JSON failed - provide helpful error message
+            raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
+
+
+def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
+    """Create helpful error messages for common result retrieval failures."""
+
+    # Package import issues
+    if isinstance(error, ImportError):
+        return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
+
+    # Package versions differ between runtime and local environment
+    if isinstance(error, AttributeError):
+        return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
+
+    # Serialization issues
+    if isinstance(error, TypeError):
+        return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
+
+    # Python version pickling incompatibility
+    if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
+        # TODO: Update this once we support different Python versions
+        client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
+        runtime_version = "Python 3.10"
+        return (
+            f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
+            f"local environment using Python {client_version}. Error: {str(error)}"
+        )
+
+    # File access issues
+    if isinstance(error, sp_exceptions.SnowparkSQLException):
+        if "not found" in str(error).lower() or "does not exist" in str(error).lower():
+            return (
+                f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
+                f"errors. Error: {str(error)}"
+            )
+        else:
+            return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"
+
+    if isinstance(error, MemoryError):
+        return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
+
+    # Generic fallback
+    base_message = f"Failed to retrieve job result: {str(error)}"
+    if json_error:
+        base_message += f" (JSON fallback also failed: {str(json_error)})"
+    return base_message


 def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
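The net effect is that a failed result fetch now surfaces a targeted message instead of a raw pickle traceback. For instance, calling the new helper directly (the exception instance is fabricated for illustration):

    msg = _fetch_result_error_message(
        ImportError("No module named 'xgboost'"), "@stage/job/mljob_result.pkl"
    )
    # -> "Failed to retrieve job result: Package not installed in your local
    #    environment. Error: No module named 'xgboost'"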
snowflake/ml/jobs/_utils/payload_utils.py
@@ -15,6 +15,7 @@ from snowflake import snowpark
 from snowflake.ml.jobs._utils import (
     constants,
     function_payload_utils,
+    query_helper,
     stage_utils,
     types,
 )
@@ -312,14 +313,15 @@ class JobPayload:
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-
+            query_helper.run_query(session, "describe stage identifier(?)", params=[stage_name])
         except sp_exceptions.SnowparkSQLException:
-
+            query_helper.run_query(
+                session,
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-        )
+            )

         # Upload payload to stage
         if not isinstance(source, (Path, stage_utils.StagePath)):
snowflake/ml/jobs/_utils/query_helper.py
@@ -0,0 +1,20 @@
+from typing import Any, Optional, Sequence
+
+from snowflake import snowpark
+from snowflake.snowpark import Row
+from snowflake.snowpark._internal import utils
+from snowflake.snowpark._internal.analyzer import snowflake_plan
+
+
+def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
+    metadata = session._conn._cursor.description
+    result_set = result["data"]
+    return utils.result_set_to_rows(result_set, metadata)
+
+
+@snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
+def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
+    result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True)
+    if not isinstance(result, dict) or "data" not in result:
+        raise ValueError(f"Unprocessable result: {result}")
+    return result_set_to_rows(session, result)
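This new module centralizes qmark-parameterized queries so the job modules no longer build SQL strings inline. A usage sketch assuming an existing Snowpark session; the query mirrors the call site in spec_utils.py below:

    rows = run_query(session, "show compute pools like ?", params=["MY_POOL"])
    if rows:
        print(rows[0]["instance_family"])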
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -16,9 +16,13 @@ import cloudpickle
 from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

 from snowflake.ml.jobs._utils import constants
-from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session

+try:
+    from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+except ImportError:
+    from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
+
 # Configure logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
snowflake/ml/jobs/_utils/spec_utils.py
@@ -6,13 +6,17 @@ from typing import Any, Optional, Union

 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, query_helper, types


 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows =
+    rows = query_helper.run_query(
+        session,
+        "show compute pools like ?",
+        params=[compute_pool],
+    )
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
@@ -180,7 +184,7 @@ def generate_service_spec(
         constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
     }
-    endpoints = []
+    endpoints: list[dict[str, Any]] = []

     if target_instances > 1:
         # Update environment variables for multi-node job
@@ -189,7 +193,7 @@ def generate_service_spec(
         env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)

         # Define Ray endpoints for intra-service instance communication
-        ray_endpoints = [
+        ray_endpoints: list[dict[str, Any]] = [
             {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
             {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
             {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
@@ -232,6 +236,19 @@ def generate_service_spec(
         ],
         "volumes": volumes,
     }
+
+    if target_instances > 1:
+        spec_dict.update(
+            {
+                "resourceManagement": {
+                    "controlPolicy": {
+                        "startupOrder": {
+                            "type": "FirstInstance",
+                        },
+                    },
+                },
+            }
+        )
     if endpoints:
         spec_dict["endpoints"] = endpoints
     if metrics: