snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/experiment/_entities/run_metadata.py (new file)
@@ -0,0 +1,68 @@
+ import dataclasses
+ import enum
+ import typing
+
+
+ class RunStatus(str, enum.Enum):
+     UNKNOWN = "UNKNOWN"
+     RUNNING = "RUNNING"
+     FINISHED = "FINISHED"
+
+
+ @dataclasses.dataclass
+ class Metric:
+     name: str
+     value: float
+     step: int
+
+
+ @dataclasses.dataclass
+ class Param:
+     name: str
+     value: str
+
+
+ @dataclasses.dataclass
+ class RunMetadata:
+     status: RunStatus
+     metrics: list[Metric]
+     parameters: list[Param]
+
+     @classmethod
+     def from_dict(
+         cls,
+         metadata: dict,  # type: ignore[type-arg]
+     ) -> "RunMetadata":
+         return RunMetadata(
+             status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
+             metrics=[Metric(**m) for m in metadata.get("metrics", [])],
+             parameters=[Param(**p) for p in metadata.get("parameters", [])],
+         )
+
+     def to_dict(self) -> dict:  # type: ignore[type-arg]
+         return dataclasses.asdict(self)
+
+     def set_metric(
+         self,
+         key: str,
+         value: float,
+         step: int,
+     ) -> None:
+         for metric in self.metrics:
+             if metric.name == key and metric.step == step:
+                 metric.value = value
+                 break
+         else:
+             self.metrics.append(Metric(name=key, value=value, step=step))
+
+     def set_param(
+         self,
+         key: str,
+         value: typing.Any,
+     ) -> None:
+         for parameter in self.parameters:
+             if parameter.name == key:
+                 parameter.value = str(value)
+                 break
+         else:
+             self.parameters.append(Param(name=key, value=str(value)))
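Not part of the diff: a minimal sketch of how the new RunMetadata container behaves. The import path follows the file layout above and the values are made up.

from snowflake.ml.experiment._entities.run_metadata import Metric, Param, RunMetadata, RunStatus

# Fresh metadata for a running experiment run.
metadata = RunMetadata(status=RunStatus.RUNNING, metrics=[], parameters=[])

# set_metric upserts by (name, step): the second call overwrites the first value.
metadata.set_metric("accuracy", 0.91, step=0)
metadata.set_metric("accuracy", 0.93, step=0)

# set_param upserts by name and stores the value as a string.
metadata.set_param("learning_rate", 0.01)

# Round-trips through plain dicts, which is how run metadata is persisted as JSON.
restored = RunMetadata.from_dict(metadata.to_dict())
assert restored.metrics[0].value == 0.93
assert restored.parameters[0].value == "0.01"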
snowflake/ml/experiment/_experiment_info.py (new file)
@@ -0,0 +1,63 @@
+ import dataclasses
+ import functools
+ import types
+ from typing import Callable, Optional
+
+ from snowflake.ml import model
+ from snowflake.ml.registry._manager import model_manager
+
+
+ @dataclasses.dataclass(frozen=True)
+ class ExperimentInfo:
+     """Serializable information identifying a Experiment"""
+
+     fully_qualified_name: str
+     run_name: str
+
+
+ class ExperimentInfoPatcher:
+     """Context manager that patches ModelManager.log_model to include experiment information.
+
+     This class maintains a stack of active experiment contexts and ensures that
+     log_model calls are automatically tagged with the appropriate experiment info.
+     """
+
+     # Store original method at class definition time to avoid recursive patching
+     _original_log_model: Callable[..., model.ModelVersion] = model_manager.ModelManager.log_model
+
+     # Stack of active experiment_info contexts for nested experiment support
+     _experiment_info_stack: list[ExperimentInfo] = []
+
+     def __init__(self, experiment_info: ExperimentInfo) -> None:
+         self._experiment_info = experiment_info
+
+     def __enter__(self) -> "ExperimentInfoPatcher":
+         # Only patch ModelManager.log_model if we're the first patcher to avoid nested patching
+         if not ExperimentInfoPatcher._experiment_info_stack:
+
+             @functools.wraps(ExperimentInfoPatcher._original_log_model)
+             def patched(*args, **kwargs) -> model.ModelVersion:  # type: ignore[no-untyped-def]
+                 # Use the most recent (top of stack) experiment_info for nested contexts
+                 current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
+                 return ExperimentInfoPatcher._original_log_model(
+                     *args, **kwargs, experiment_info=current_experiment_info
+                 )
+
+             model_manager.ModelManager.log_model = patched  # type: ignore[method-assign]
+
+         ExperimentInfoPatcher._experiment_info_stack.append(self._experiment_info)
+         return self
+
+     def __exit__(
+         self,
+         exc_type: Optional[type[BaseException]],
+         exc_value: Optional[BaseException],
+         traceback: Optional[types.TracebackType],
+     ) -> None:
+         ExperimentInfoPatcher._experiment_info_stack.pop()
+
+         # Restore original method when no patches are active to clean up properly
+         if not ExperimentInfoPatcher._experiment_info_stack:
+             model_manager.ModelManager.log_model = (  # type: ignore[method-assign]
+                 ExperimentInfoPatcher._original_log_model
+             )
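Not part of the diff: a rough sketch of how the patcher is used; the ExperimentInfo values are illustrative. While a patcher context is active, ModelManager.log_model receives the experiment_info of the innermost context, and the unpatched method is restored once the last context exits.

from snowflake.ml.experiment._experiment_info import ExperimentInfo, ExperimentInfoPatcher

outer = ExperimentInfo(fully_qualified_name="MY_DB.MY_SCHEMA.OUTER_EXP", run_name="RUN_1")
inner = ExperimentInfo(fully_qualified_name="MY_DB.MY_SCHEMA.INNER_EXP", run_name="RUN_2")

with ExperimentInfoPatcher(outer):
    # log_model calls made here are tagged with `outer`
    with ExperimentInfoPatcher(inner):
        # nested calls are tagged with `inner` (top of the stack)
        ...
    # back to `outer`
# stack is empty again, so the original log_model is restored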
snowflake/ml/experiment/experiment_tracking.py (new file)
@@ -0,0 +1,319 @@
+ import functools
+ import json
+ import sys
+ from typing import Any, Optional, Union
+ from urllib.parse import quote
+
+ import snowflake.snowpark._internal.utils as snowpark_utils
+ from snowflake.ml import model, registry
+ from snowflake.ml._internal.human_readable_id import hrid_generator
+ from snowflake.ml._internal.utils import sql_identifier
+ from snowflake.ml.experiment import (
+     _entities as entities,
+     _experiment_info as experiment_info,
+ )
+ from snowflake.ml.experiment._client import experiment_tracking_sql_client as sql_client
+ from snowflake.ml.model import type_hints
+ from snowflake.ml.utils import sql_client as sql_client_utils
+ from snowflake.snowpark import session
+
+ DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
+
+
+ class ExperimentTracking:
+     """
+     Class to manage experiments in Snowflake.
+     """
+
+     @snowpark_utils.private_preview(version="1.9.1")
+     def __init__(
+         self,
+         session: session.Session,
+         *,
+         database_name: Optional[str] = None,
+         schema_name: Optional[str] = None,
+     ) -> None:
+         """
+         Initializes experiment tracking within a pre-created schema.
+
+         Args:
+             session: The Snowpark Session to connect with Snowflake.
+             database_name: The name of the database. If None, the current database of the session
+                 will be used. Defaults to None.
+             schema_name: The name of the schema. If None, the current schema of the session
+                 will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
+
+         Raises:
+             ValueError: If no database is provided and no active database exists in the session.
+         """
+         if database_name:
+             self._database_name = sql_identifier.SqlIdentifier(database_name)
+         elif session_db := session.get_current_database():
+             self._database_name = sql_identifier.SqlIdentifier(session_db)
+         else:
+             raise ValueError("You need to provide a database to use experiment tracking.")
+
+         if schema_name:
+             self._schema_name = sql_identifier.SqlIdentifier(schema_name)
+         elif session_schema := session.get_current_schema():
+             self._schema_name = sql_identifier.SqlIdentifier(session_schema)
+         else:
+             self._schema_name = sql_identifier.SqlIdentifier("PUBLIC")
+
+         self._sql_client = sql_client.ExperimentTrackingSQLClient(
+             session,
+             database_name=self._database_name,
+             schema_name=self._schema_name,
+         )
+         self._registry = registry.Registry(
+             session=session,
+             database_name=self._database_name,
+             schema_name=self._schema_name,
+         )
+
+         # The experiment in context
+         self._experiment: Optional[entities.Experiment] = None
+         # The run in context
+         self._run: Optional[entities.Run] = None
+
+     def set_experiment(
+         self,
+         experiment_name: str,
+     ) -> entities.Experiment:
+         """
+         Set the experiment in context. Creates a new experiment if it doesn't exist.
+
+         Args:
+             experiment_name: The name of the experiment.
+
+         Returns:
+             Experiment: The experiment that was set.
+         """
+         experiment_name = sql_identifier.SqlIdentifier(experiment_name)
+         if self._experiment and self._experiment.name == experiment_name:
+             return self._experiment
+         self._sql_client.create_experiment(
+             experiment_name=experiment_name,
+             creation_mode=sql_client_utils.CreationMode(if_not_exists=True),
+         )
+         self._experiment = entities.Experiment(experiment_name=experiment_name)
+         self._run = None
+         return self._experiment
+
+     def delete_experiment(
+         self,
+         experiment_name: str,
+     ) -> None:
+         """
+         Delete an experiment.
+
+         Args:
+             experiment_name: The name of the experiment.
+         """
+         self._sql_client.drop_experiment(experiment_name=sql_identifier.SqlIdentifier(experiment_name))
+         if self._experiment and self._experiment.name == experiment_name:
+             self._experiment = None
+             self._run = None
+
+     @functools.wraps(registry.Registry.log_model)
+     def log_model(
+         self,
+         model: Union[type_hints.SupportedModelType, model.ModelVersion],
+         *,
+         model_name: str,
+         **kwargs: Any,
+     ) -> model.ModelVersion:
+         run = self._get_or_start_run()
+         with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
+             return self._registry.log_model(model, model_name=model_name, **kwargs)
+
+     def start_run(
+         self,
+         run_name: Optional[str] = None,
+     ) -> entities.Run:
+         """
+         Start a new run.
+
+         Args:
+             run_name: The name of the run. If None, a default name will be generated.
+
+         Returns:
+             Run: The run that was started.
+
+         Raises:
+             RuntimeError: If a run is already active.
+         """
+         if self._run:
+             raise RuntimeError("A run is already active. Please end the current run before starting a new one.")
+         experiment = self._get_or_set_experiment()
+         run_name = (
+             sql_identifier.SqlIdentifier(run_name) if run_name is not None else self._generate_run_name(experiment)
+         )
+         self._sql_client.add_run(
+             experiment_name=experiment.name,
+             run_name=run_name,
+         )
+         self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
+         return self._run
+
+     def end_run(self, run_name: Optional[str] = None) -> None:
+         """
+         End the current run if no run name is provided. Otherwise, the specified run is ended.
+
+         Args:
+             run_name: The name of the run to be ended. If None, the current run is ended.
+
+         Raises:
+             RuntimeError: If no run is active.
+         """
+         if not self._experiment:
+             raise RuntimeError("No experiment set. Please set an experiment before ending a run.")
+         experiment_name = self._experiment.name
+
+         if run_name:
+             run_name = sql_identifier.SqlIdentifier(run_name)
+         elif self._run:
+             run_name = self._run.name
+         else:
+             raise RuntimeError("No run is active. Please start a run before ending it.")
+
+         self._sql_client.commit_run(
+             experiment_name=experiment_name,
+             run_name=run_name,
+         )
+         if self._run and run_name == self._run.name:
+             self._run = None
+         self._print_urls(experiment_name=experiment_name, run_name=run_name)
+
+     def delete_run(
+         self,
+         run_name: str,
+     ) -> None:
+         """
+         Delete a run.
+
+         Args:
+             run_name: The name of the run to be deleted.
+
+         Raises:
+             RuntimeError: If no experiment is set.
+         """
+         if not self._experiment:
+             raise RuntimeError("No experiment set. Please set an experiment before deleting a run.")
+         self._sql_client.drop_run(
+             experiment_name=self._experiment.name,
+             run_name=sql_identifier.SqlIdentifier(run_name),
+         )
+         if self._run and self._run.name == run_name:
+             self._run = None
+
+     def log_metric(
+         self,
+         key: str,
+         value: float,
+         step: int = 0,
+     ) -> None:
+         """
+         Log a metric under the current run. If no run is active, this method will create a new run.
+
+         Args:
+             key: The name of the metric.
+             value: The value of the metric.
+             step: The step of the metric. Defaults to 0.
+         """
+         self.log_metrics(metrics={key: value}, step=step)
+
+     def log_metrics(
+         self,
+         metrics: dict[str, float],
+         step: int = 0,
+     ) -> None:
+         """
+         Log metrics under the current run. If no run is active, this method will create a new run.
+
+         Args:
+             metrics: Dictionary containing metric keys and float values.
+             step: The step of the metrics. Defaults to 0.
+         """
+         run = self._get_or_start_run()
+         metadata = run._get_metadata()
+         for key, value in metrics.items():
+             metadata.set_metric(key, value, step)
+         self._sql_client.modify_run(
+             experiment_name=run.experiment_name,
+             run_name=run.name,
+             run_metadata=json.dumps(metadata.to_dict()),
+         )
+
+     def log_param(
+         self,
+         key: str,
+         value: Any,
+     ) -> None:
+         """
+         Log a parameter under the current run. If no run is active, this method will create a new run.
+
+         Args:
+             key: The name of the parameter.
+             value: The value of the parameter. Values can be of any type, but will be converted to string.
+         """
+         self.log_params({key: value})
+
+     def log_params(
+         self,
+         params: dict[str, Any],
+     ) -> None:
+         """
+         Log parameters under the current run. If no run is active, this method will create a new run.
+
+         Args:
+             params: Dictionary containing parameter keys and values. Values can be of any type, but will be converted
+                 to string.
+         """
+         run = self._get_or_start_run()
+         metadata = run._get_metadata()
+         for key, value in params.items():
+             metadata.set_param(key, value)
+         self._sql_client.modify_run(
+             experiment_name=run.experiment_name,
+             run_name=run.name,
+             run_metadata=json.dumps(metadata.to_dict()),
+         )
+
+     def _get_or_set_experiment(self) -> entities.Experiment:
+         if self._experiment:
+             return self._experiment
+         return self.set_experiment(experiment_name=DEFAULT_EXPERIMENT_NAME)
+
+     def _get_or_start_run(self) -> entities.Run:
+         if self._run:
+             return self._run
+         return self.start_run()
+
+     def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
+         generator = hrid_generator.HRID16()
+         existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
+         existing_run_names = [row[sql_client.ExperimentTrackingSQLClient.RUN_NAME_COL_NAME] for row in existing_runs]
+         for _ in range(1000):
+             run_name = generator.generate()[1]
+             if run_name not in existing_run_names:
+                 return sql_identifier.SqlIdentifier(run_name)
+         raise RuntimeError("Random run name generation failed.")
+
+     def _print_urls(
+         self,
+         experiment_name: sql_identifier.SqlIdentifier,
+         run_name: sql_identifier.SqlIdentifier,
+         scheme: str = "https",
+         host: str = "app.snowflake.com",
+     ) -> None:
+
+         experiment_url = (
+             f"{scheme}://{host}/_deeplink/#/experiments"
+             f"/databases/{quote(str(self._database_name))}"
+             f"/schemas/{quote(str(self._schema_name))}"
+             f"/experiments/{quote(str(experiment_name))}"
+         )
+         run_url = experiment_url + f"/runs/{quote(str(run_name))}"
+         sys.stdout.write(f"🏃 View run {run_name} at: {run_url}\n")
+         sys.stdout.write(f"🧪 View experiment at: {experiment_url}\n")
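Not part of the diff: a hedged end-to-end sketch of the new experiment-tracking API; connection parameters, database, schema, and names below are placeholders.

from snowflake.snowpark import Session
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking

connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}  # placeholders
session = Session.builder.configs(connection_parameters).create()

exp = ExperimentTracking(session, database_name="ML_DB", schema_name="EXPERIMENTS")
exp.set_experiment("MY_EXPERIMENT")

exp.start_run("MY_RUN")                      # omit the name to get a generated one
exp.log_param("model_type", "xgboost")
exp.log_metrics({"rmse": 0.42, "r2": 0.87}, step=0)
# exp.log_model(my_model, model_name="MY_MODEL")  # logged models are linked to the active run
exp.end_run()                                # commits the run and prints Snowsight deeplinks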
snowflake/ml/jobs/_utils/constants.py
@@ -15,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
  DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
  DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
- DEFAULT_IMAGE_TAG = "1.4.2"
+ DEFAULT_IMAGE_TAG = "1.5.0"
  DEFAULT_ENTRYPOINT_PATH = "func.py"

  # Percent of container memory to allocate for /dev/shm volume
snowflake/ml/jobs/_utils/interop_utils.py
@@ -75,16 +75,75 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult

      Returns:
          A dictionary containing the execution result if available, None otherwise.
+
+     Raises:
+         RuntimeError: If both pickle and JSON result retrieval fail.
      """
      try:
          # TODO: Check if file exists
          with session.file.get_stream(result_path) as result_stream:
              return ExecutionResult.from_dict(pickle.load(result_stream))
-     except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
+     except (
+         sp_exceptions.SnowparkSQLException,
+         pickle.UnpicklingError,
+         TypeError,
+         ImportError,
+         AttributeError,
+         MemoryError,
+     ) as pickle_error:
          # Fall back to JSON result if loading pickled result fails for any reason
-         result_json_path = os.path.splitext(result_path)[0] + ".json"
-         with session.file.get_stream(result_json_path) as result_stream:
-             return ExecutionResult.from_dict(json.load(result_stream))
+         try:
+             result_json_path = os.path.splitext(result_path)[0] + ".json"
+             with session.file.get_stream(result_json_path) as result_stream:
+                 return ExecutionResult.from_dict(json.load(result_stream))
+         except Exception as json_error:
+             # Both pickle and JSON failed - provide helpful error message
+             raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
+
+
+ def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
+     """Create helpful error messages for common result retrieval failures."""
+
+     # Package import issues
+     if isinstance(error, ImportError):
+         return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
+
+     # Package versions differ between runtime and local environment
+     if isinstance(error, AttributeError):
+         return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
+
+     # Serialization issues
+     if isinstance(error, TypeError):
+         return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
+
+     # Python version pickling incompatibility
+     if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
+         # TODO: Update this once we support different Python versions
+         client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
+         runtime_version = "Python 3.10"
+         return (
+             f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
+             f"local environment using Python {client_version}. Error: {str(error)}"
+         )
+
+     # File access issues
+     if isinstance(error, sp_exceptions.SnowparkSQLException):
+         if "not found" in str(error).lower() or "does not exist" in str(error).lower():
+             return (
+                 f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
+                 f"errors. Error: {str(error)}"
+             )
+         else:
+             return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"
+
+     if isinstance(error, MemoryError):
+         return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
+
+     # Generic fallback
+     base_message = f"Failed to retrieve job result: {str(error)}"
+     if json_error:
+         base_message += f" (JSON fallback also failed: {str(json_error)})"
+     return base_message


  def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
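Not part of the diff: what the widened handling means for callers, as a hedged sketch (the session and stage path are illustrative). fetch_result now falls back to the JSON result and only raises once both attempts fail, with a message naming the likely cause.

from snowflake.ml.jobs._utils import interop_utils

# `session` is an existing snowflake.snowpark Session; the result path is made up.
try:
    result = interop_utils.fetch_result(session, "@MY_JOB_STAGE/output/mljob_result.pkl")
except RuntimeError as err:
    # Raised only after both the pickle result and the JSON fallback fail, e.g.
    # "Failed to retrieve job result: Package version mismatch. Error: ..."
    print(err)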
snowflake/ml/jobs/_utils/payload_utils.py
@@ -15,6 +15,7 @@ from snowflake import snowpark
  from snowflake.ml.jobs._utils import (
      constants,
      function_payload_utils,
+     query_helper,
      stage_utils,
      types,
  )
@@ -312,14 +313,15 @@ class JobPayload:
          stage_name = stage_path.parts[0].lstrip("@")
          # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
          try:
-             session.sql("describe stage identifier(?)", params=[stage_name]).collect()
+             query_helper.run_query(session, "describe stage identifier(?)", params=[stage_name])
          except sp_exceptions.SnowparkSQLException:
-             session.sql(
+             query_helper.run_query(
+                 session,
                  "create stage if not exists identifier(?)"
                  " encryption = ( type = 'SNOWFLAKE_SSE' )"
                  " comment = 'Created by snowflake.ml.jobs Python API'",
                  params=[stage_name],
-             ).collect()
+             )

          # Upload payload to stage
          if not isinstance(source, (Path, stage_utils.StagePath)):
snowflake/ml/jobs/_utils/query_helper.py (new file)
@@ -0,0 +1,20 @@
+ from typing import Any, Optional, Sequence
+
+ from snowflake import snowpark
+ from snowflake.snowpark import Row
+ from snowflake.snowpark._internal import utils
+ from snowflake.snowpark._internal.analyzer import snowflake_plan
+
+
+ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]:
+     metadata = session._conn._cursor.description
+     result_set = result["data"]
+     return utils.result_set_to_rows(result_set, metadata)
+
+
+ @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
+ def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]:
+     result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True)
+     if not isinstance(result, dict) or "data" not in result:
+         raise ValueError(f"Unprocessable result: {result}")
+     return result_set_to_rows(session, result)
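Not part of the diff: a small sketch of the helper that the jobs code now routes its queries through; `session` is an existing Snowpark session and the pool name is illustrative. run_query executes with qmark-style bind parameters and returns snowpark Row objects, much like session.sql(...).collect().

from snowflake.ml.jobs._utils import query_helper

rows = query_helper.run_query(session, "show compute pools like ?", params=["MY_POOL"])
if rows:
    print(rows[0]["instance_family"])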
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -16,9 +16,13 @@ import cloudpickle
  from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR

  from snowflake.ml.jobs._utils import constants
- from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
  from snowflake.snowpark import Session

+ try:
+     from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+ except ImportError:
+     from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
+
  # Configure logging
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
  logger = logging.getLogger(__name__)
snowflake/ml/jobs/_utils/spec_utils.py
@@ -6,13 +6,17 @@ from typing import Any, Optional, Union

  from snowflake import snowpark
  from snowflake.ml._internal.utils import snowflake_env
- from snowflake.ml.jobs._utils import constants, types
+ from snowflake.ml.jobs._utils import constants, query_helper, types


  def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
      """Extract resource information for the specified compute pool"""
      # Get the instance family
-     rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
+     rows = query_helper.run_query(
+         session,
+         "show compute pools like ?",
+         params=[compute_pool],
+     )
      if not rows:
          raise ValueError(f"Compute pool '{compute_pool}' not found")
      instance_family: str = rows[0]["instance_family"]
@@ -180,7 +184,7 @@ def generate_service_spec(
          constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
          constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
      }
-     endpoints = []
+     endpoints: list[dict[str, Any]] = []

      if target_instances > 1:
          # Update environment variables for multi-node job
@@ -189,7 +193,7 @@ def generate_service_spec(
          env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)

          # Define Ray endpoints for intra-service instance communication
-         ray_endpoints = [
+         ray_endpoints: list[dict[str, Any]] = [
              {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
              {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
              {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
@@ -232,6 +236,19 @@ def generate_service_spec(
          ],
          "volumes": volumes,
      }
+
+     if target_instances > 1:
+         spec_dict.update(
+             {
+                 "resourceManagement": {
+                     "controlPolicy": {
+                         "startupOrder": {
+                             "type": "FirstInstance",
+                         },
+                     },
+                 },
+             }
+         )
      if endpoints:
          spec_dict["endpoints"] = endpoints
      if metrics: