snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -12,12 +12,14 @@ from snowflake.ml.jobs._utils import constants, query_helper, types
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session._conn.run_query("show compute pools like ?", params=[compute_pool], _force_qmark_paramstyle=True)
-    if not rows or not isinstance(rows, dict) or not rows.get("data"):
+    rows = query_helper.run_query(
+        session,
+        "show compute pools like ?",
+        params=[compute_pool],
+    )
+    if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-    requested_attributes = query_helper.get_attribute_map(session, {"instance_family": 4})
-    compute_pool_info = rows["data"]
-    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
+    instance_family: str = rows[0]["instance_family"]
 
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
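This release swaps the raw `session._conn.run_query` dict payload (and positional `get_attribute_map` lookups) for the internal `query_helper.run_query` wrapper, which returns rows addressable by column name. A minimal usage sketch based on the call shape above; the session and pool name are illustrative, and since `run_query` is an internal helper its exact signature may differ:

    from snowflake.ml.jobs._utils import query_helper

    # Assumes an active snowpark.Session named `session` and a pool named MY_POOL.
    rows = query_helper.run_query(session, "show compute pools like ?", params=["MY_POOL"])
    if not rows:
        raise ValueError("Compute pool 'MY_POOL' not found")
    # Rows support access by column name instead of positional index:
    print(rows[0]["instance_family"])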
@@ -179,10 +181,10 @@ def generate_service_spec(
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
     env_vars = {
-        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+        constants.PAYLOAD_DIR_ENV_VAR: constants.APP_MOUNT_PATH,
         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
     }
-    endpoints = []
+    endpoints: list[dict[str, Any]] = []
 
     if target_instances > 1:
         # Update environment variables for multi-node job
@@ -191,7 +193,7 @@ def generate_service_spec(
         env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
 
         # Define Ray endpoints for intra-service instance communication
-        ray_endpoints = [
+        ray_endpoints: list[dict[str, Any]] = [
            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
@@ -234,6 +236,19 @@ def generate_service_spec(
         ],
         "volumes": volumes,
     }
+
+    if target_instances > 1:
+        spec_dict.update(
+            {
+                "resourceManagement": {
+                    "controlPolicy": {
+                        "startupOrder": {
+                            "type": "FirstInstance",
+                        },
+                    },
+                },
+            }
+        )
     if endpoints:
         spec_dict["endpoints"] = endpoints
     if metrics:
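For multi-instance jobs the spec now requests a `FirstInstance` startup order, which appears intended to let the first (head) instance come up before the workers that connect to it. A minimal sketch of the resulting spec fragment; the container/volume entries are placeholders, not the full output of `generate_service_spec`:

    import yaml

    target_instances = 2
    spec_dict = {"containers": [], "volumes": []}  # placeholder spec body
    if target_instances > 1:
        # Ask SPCS to start instance 0 before the remaining replicas.
        spec_dict.update(
            {"resourceManagement": {"controlPolicy": {"startupOrder": {"type": "FirstInstance"}}}}
        )
    print(yaml.safe_dump({"spec": spec_dict}, sort_keys=False))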
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -14,7 +14,10 @@ _SNOWURL_PATH_RE = re.compile(
     r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
 )
 
-_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+# Break long regex into two main parts
+_STAGE_PATTERN = rf"~|%?(?:(?:{identifier._SF_IDENTIFIER}\.?){{,2}}{identifier._SF_IDENTIFIER})"
+_RELPATH_PATTERN = r"[\w\-./]*"
+_STAGEF_PATH_RE = re.compile(rf"^@(?P<stage>{_STAGE_PATTERN})(?:/(?P<relpath>{_RELPATH_PATTERN}))?$")
 
 
 class StagePath:
@@ -29,6 +32,14 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")
 
+    @property
+    def parts(self) -> tuple[str, ...]:
+        return self._path.parts
+
+    @property
+    def name(self) -> str:
+        return self._path.name
+
     @property
     def parent(self) -> "StagePath":
         if self._path.parent == Path(""):
@@ -51,18 +62,28 @@ class StagePath:
         else:
             return f"{self.root}/{path}"
 
-    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        if not other:
+            raise TypeError("is_relative_to() requires at least one argument")
+        # For now, we only support a single argument, like pathlib.Path in Python < 3.12
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if stage_path.root == self.root:
             return self._path.is_relative_to(stage_path._path)
         else:
             return False
 
-    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        if not other:
+            raise TypeError("relative_to() requires at least one argument")
+        if not self.is_relative_to(*other):
+            raise ValueError(f"{other} does not start with {self._raw_path}")
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             return self._path.relative_to(stage_path._path)
-        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+        else:
+            raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
 
     def absolute(self) -> "StagePath":
         return self
@@ -88,6 +109,9 @@ class StagePath:
     def __str__(self) -> str:
         return self.as_posix()
 
+    def __repr__(self) -> str:
+        return f"StagePath('{self.as_posix()}')"
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, StagePath):
             raise NotImplementedError
@@ -96,24 +120,16 @@ class StagePath:
     def __fspath__(self) -> str:
         return self._compose_path(self._path)
 
-    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+    def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath":
         path = self
         for arg in args:
             path = path._make_child(arg)
         return path
 
-    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+    def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath":
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             child_path = self._path.joinpath(stage_path._path)
             return StagePath(self._compose_path(child_path))
         else:
             return stage_path
-
-
-def identify_stage_path(path: str) -> Union[StagePath, Path]:
-    try:
-        stage_path = StagePath(path)
-    except ValueError:
-        return Path(path)
-    return stage_path
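`StagePath` now mirrors more of the `pathlib.Path` surface (`parts`, `name`, `__repr__`, and the single-argument `*other` form of `is_relative_to`/`relative_to` that `pathlib` used before Python 3.12), and the stage regex now also accepts qualified `db.schema.stage` names. A hedged usage sketch based only on the methods shown in this diff; the expected outputs are inferred, not verified:

    from snowflake.ml.jobs._utils.stage_utils import StagePath

    p = StagePath("@my_stage/jobs/run_1/output/results.json")
    print(p.name)   # expected: results.json
    print(p.parts)  # expected: ('jobs', 'run_1', 'output', 'results.json')

    base = StagePath("@my_stage/jobs/run_1")
    print(p.is_relative_to(base))          # expected: True (same stage root)
    print(p.relative_to(base).as_posix())  # expected: output/results.json

    other = StagePath("@other_stage/jobs")
    print(p.is_relative_to(other))         # expected: False (different stage root)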
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -1,8 +1,7 @@
+import os
 from dataclasses import dataclass
 from pathlib import PurePath
-from typing import Literal, Optional, Union
-
-from snowflake.ml.jobs._utils import stage_utils
+from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
 
 JOB_STATUS = Literal[
     "PENDING",
@@ -15,9 +14,70 @@ JOB_STATUS = Literal[
 ]
 
 
+@runtime_checkable
+class PayloadPath(Protocol):
+    """A protocol for path-like objects used in this module, covering methods from pathlib.Path and StagePath."""
+
+    @property
+    def name(self) -> str:
+        ...
+
+    @property
+    def suffix(self) -> str:
+        ...
+
+    @property
+    def parent(self) -> "PayloadPath":
+        ...
+
+    def exists(self) -> bool:
+        ...
+
+    def is_file(self) -> bool:
+        ...
+
+    def is_absolute(self) -> bool:
+        ...
+
+    def absolute(self) -> "PayloadPath":
+        ...
+
+    def joinpath(self, *other: Union[str, os.PathLike[str]]) -> "PayloadPath":
+        ...
+
+    def as_posix(self) -> str:
+        ...
+
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        ...
+
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        ...
+
+    def __fspath__(self) -> str:
+        ...
+
+    def __str__(self) -> str:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+@dataclass
+class PayloadSpec:
+    """Represents a payload item to be uploaded."""
+
+    source_path: PayloadPath
+    remote_relative_path: Optional[PurePath] = None
+
+    def __iter__(self) -> Iterator[Union[PayloadPath, Optional[PurePath]]]:
+        return iter((self.source_path, self.remote_relative_path))
+
+
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: Union[PurePath, stage_utils.StagePath]
+    file_path: PayloadPath
     main_func: Optional[str]
 
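`PayloadPath` is a structural (duck-typed) interface: because it is `@runtime_checkable`, any object exposing the listed attributes passes an `isinstance` check, so both `pathlib.Path` and `StagePath` qualify without inheriting from it, and the `stage_utils` import cycle goes away. A small sketch of the idea; import paths follow this diff, and the behavior assumes Python 3.9+ where `Path.is_relative_to` exists:

    from pathlib import Path

    from snowflake.ml.jobs._utils.types import PayloadPath, PayloadSpec

    local = Path("train.py")
    print(isinstance(local, PayloadPath))  # True: Path has name, suffix, exists, ...

    # PayloadSpec defines __iter__, so it unpacks like a (source, remote) pair:
    source, remote = PayloadSpec(source_path=local)
    print(source, remote)  # train.py None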
snowflake/ml/jobs/job.py CHANGED
@@ -3,12 +3,12 @@ import logging
 import os
 import time
 from functools import cached_property
+from pathlib import Path
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 
 import yaml
 
 from snowflake import snowpark
-from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
@@ -70,8 +70,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
     def _compute_pool(self) -> str:
         """Get the job's compute pool name."""
         row = _get_service_info(self._session, self.id)
-        compute_pool = row[query_helper.get_attribute_map(self._session, {"compute_pool": 5})["compute_pool"]]
-        return cast(str, compute_pool)
+        return cast(str, row["compute_pool"])
 
     @property
     def _service_spec(self) -> dict[str, Any]:
@@ -97,10 +96,24 @@ class MLJob(Generic[T], SerializableSessionMixin):
     @property
     def _result_path(self) -> str:
         """Get the job's result file location."""
-        result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
-        if result_path is None:
+        result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path_str is None:
             raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
-        return f"{self._stage_path}/{result_path}"
+        volume_mounts = self._container_spec["volumeMounts"]
+        stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
+
+        result_path = Path(result_path_str)
+        stage_mount = Path(stage_mount_str)
+        try:
+            relative_path = result_path.relative_to(stage_mount)
+        except ValueError:
+            if result_path.is_absolute():
+                raise ValueError(
+                    f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
+                )
+            relative_path = result_path
+
+        return f"{self._stage_path}/{relative_path.as_posix()}"
 
     @overload
     def get_logs(
@@ -183,20 +196,17 @@ class MLJob(Generic[T], SerializableSessionMixin):
         start_time = time.monotonic()
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
-            if status == "PENDING" and not warning_shown:
+            elapsed = time.monotonic() - start_time
+            if elapsed >= timeout >= 0:
+                raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds")
+            elif status == "PENDING" and not warning_shown and elapsed >= 2:  # Only show warning after 2s
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
-                requested_attributes = {"max_nodes": 3, "active_nodes": 9}
-                if (
-                    pool_info[requested_attributes["max_nodes"]] - pool_info[requested_attributes["active_nodes"]]
-                ) < self.min_instances:
+                if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
                     logger.warning(
-                        f'Compute pool busy ({pool_info[requested_attributes["active_nodes"]]}'
-                        f'/{pool_info[requested_attributes["max_nodes"]]} nodes in use, '
+                        f"Compute pool busy ({pool_info.active_nodes}/{pool_info.max_nodes} nodes in use, "
                         f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
-            if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
@@ -247,27 +257,21 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
     """Retrieve job or job instance execution status."""
     if instance_id is not None:
         # Get specific instance status
-        rows = session._conn.run_query(
-            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job_id], _force_qmark_paramstyle=True
-        )
-        request_attributes = query_helper.get_attribute_map(session, {"status": 5, "instance_id": 4})
-        if isinstance(rows, dict) and "data" in rows:
-            for row in rows["data"]:
-                if row[request_attributes["instance_id"]] == str(instance_id):
-                    return cast(types.JOB_STATUS, row[request_attributes["status"]])
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        for row in rows:
+            if row["instance_id"] == str(instance_id):
+                return cast(types.JOB_STATUS, row["status"])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
         row = _get_service_info(session, job_id)
-        request_attributes = query_helper.get_attribute_map(session, {"status": 1})
-        return cast(types.JOB_STATUS, row[request_attributes["status"]])
+        return cast(types.JOB_STATUS, row["status"])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
     row = _get_service_info(session, job_id)
-    requested_attributes = query_helper.get_attribute_map(session, {"spec": 6})
-    return cast(dict[str, Any], yaml.safe_load(row[requested_attributes["spec"]]))
+    return cast(dict[str, Any], yaml.safe_load(row["spec"]))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
@@ -307,18 +311,14 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-        data = session._conn.run_query(
+        (row,) = query_helper.run_query(
+            session,
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-            _force_qmark_paramstyle=True,
         )
-        if isinstance(data, dict) and "data" in data:
-            full_log = str(data["data"][0][0])
-            # pass type check
-        else:
-            full_log = ""
-    except errors.ProgrammingError as e:
-        if "Container Status: PENDING" in str(e):
+        full_log = str(row[0])
+    except SnowparkSQLException as e:
+        if "Container Status: PENDING" in e.message:
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
@@ -399,7 +399,7 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
 
     try:
         target_instances = _get_target_instances(session, job_id)
-    except errors.ProgrammingError:
+    except SnowparkSQLException:
        # service may be deleted
        raise RuntimeError("Couldn’t retrieve service information")
 
@@ -407,34 +407,32 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return 0
 
     try:
-        rows = session._conn.run_query(
-            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True
+        rows = query_helper.run_query(
+            session,
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)",
+            params=(job_id,),
         )
-    except errors.ProgrammingError:
+    except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")
 
-    if not rows or not isinstance(rows, dict) or not rows.get("data"):
+    if not rows:
        return None
 
-    if target_instances > len(rows["data"]):
+    if target_instances > len(rows):
        raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
-    requested_attributes = query_helper.get_attribute_map(session, {"start_time": 8, "instance_id": 4})
     # Sort by start_time first, then by instance_id
     try:
-        sorted_instances = sorted(
-            rows["data"],
-            key=lambda x: (x[requested_attributes["start_time"]], int(x[requested_attributes["instance_id"]])),
-        )
+        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
     except TypeError:
         raise RuntimeError("Job instance information unavailable.")
     head_instance = sorted_instances[0]
-    if not head_instance[requested_attributes["start_time"]]:
+    if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
         return None
     try:
-        return int(head_instance[requested_attributes["instance_id"]])
+        return int(head_instance["instance_id"])
     except (ValueError, TypeError):
         return 0
 
@@ -446,14 +444,16 @@ def _get_service_log_from_event_table(
     schema: Optional[str] = None,
     instance_id: Optional[int] = None,
     limit: int = -1,
-) -> Any:
-    params: list[Any] = [
-        name,
-    ]
+) -> list[Row]:
+    event_table_name = session.sql("SHOW PARAMETERS LIKE 'event_table' IN ACCOUNT").collect()[0]["value"]
     query = [
-        "SELECT VALUE FROM snowflake.telemetry.events_view",
+        "SELECT VALUE FROM IDENTIFIER(?)",
         'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
     ]
+    params: list[Any] = [
+        event_table_name,
+        name,
+    ]
     if database:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
         params.append(database)
@@ -473,23 +473,22 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-    rows = session._conn.run_query(
-        "\n".join(line for line in query if line), params=params, _force_qmark_paramstyle=True
+    # the wrap used in query_helper does not have return type.
+    # sticking a # type: ignore[no-any-return] is to pass type check
+    rows = query_helper.run_query(
+        session,
+        "\n".join(line for line in query if line),
+        params=params,
     )
-    if not rows or not isinstance(rows, dict) or not rows.get("data"):
-        return []
-    return rows["data"]
+    return rows  # type: ignore[no-any-return]
 
 
 def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
-    row = session._conn.run_query("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,), _force_qmark_paramstyle=True)
-    # pass the type check
-    if not row or not isinstance(row, dict) or not row.get("data"):
-        raise errors.ProgrammingError("failed to retrieve service information")
-    return row["data"][0]
+    (row,) = query_helper.run_query(session, "DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,))
+    return row
 
 
-def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
+def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
     """
     Check if the compute pool has enough available instances.
 
@@ -498,19 +497,16 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
         compute_pool (str): The name of the compute pool.
 
     Returns:
-        Any: The compute pool information.
+        Row: The compute pool information.
 
     Raises:
         ValueError: If the compute pool is not found.
     """
     try:
-        compute_pool_info = session._conn.run_query(
-            "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,), _force_qmark_paramstyle=True
-        )
-        # pass the type check
-        if not compute_pool_info or not isinstance(compute_pool_info, dict) or not compute_pool_info.get("data"):
-            raise ValueError(f"Compute pool '{compute_pool}' not found")
-        return compute_pool_info["data"][0]
+        # the wrap used in query_helper does not have return type.
+        # sticking a # type: ignore[no-any-return] is to pass type check
+        (pool_info,) = query_helper.run_query(session, "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,))
+        return pool_info  # type: ignore[no-any-return]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
@@ -520,8 +516,7 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Any:
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
-    requested_attributes = query_helper.get_attribute_map(session, {"target_instances": 9})
-    return int(row[requested_attributes["target_instances"]])
+    return int(row["target_instances"])
 
 
 def _get_logs_spcs(
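The reworked `wait` loop checks the timeout before anything else, suppresses the busy-pool warning for the first two seconds, and keeps the capped exponential backoff between polls. A standalone sketch of the resulting backoff schedule; the initial delay and cap here are illustrative stand-ins, while the real values live in `snowflake.ml.jobs._utils.constants`:

    delay = 0.1       # assumed initial poll delay
    max_delay = 30.0  # stand-in for constants.JOB_POLL_MAX_DELAY_SECONDS
    schedule = []
    for _ in range(10):
        schedule.append(round(delay, 3))
        delay = min(delay * 1.2, max_delay)  # grow 20% per poll, capped
    print(schedule)  # [0.1, 0.12, 0.144, 0.173, 0.207, 0.249, 0.299, 0.358, 0.43, 0.516]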
snowflake/ml/jobs/manager.py CHANGED
@@ -8,7 +8,6 @@ import pandas as pd
 import yaml
 
 from snowflake import snowpark
-from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
@@ -169,8 +168,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         job = jb.MLJob[Any](job_id, session=session)
         _ = job._service_spec
         return job
-    except errors.ProgrammingError as e:
-        if "does not exist" in str(e):
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
             raise ValueError(f"Job does not exist: {job_id}") from e
         raise
 
@@ -186,7 +185,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
             logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
         except Exception as e:
             logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-    session._conn.run_query("DROP SERVICE IDENTIFIER(?)", params=(job.id,), _force_qmark_paramstyle=True)
+    query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -426,12 +425,18 @@ def _submit_job(
         An object representing the submitted job.
 
     Raises:
-        RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
-        errors.ProgrammingError: if the SQL query or its parameters are invalid
+        SnowparkSQLException: If there is an error submitting the job.
     """
     session = session or get_active_session()
 
+    # Check for deprecated args
+    if "num_instances" in kwargs:
+        logger.warning(
+            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+        )
+        target_instances = max(target_instances, kwargs.pop("num_instances"))
+
     # Use kwargs for less common optional parameters
     database = kwargs.pop("database", None)
     schema = kwargs.pop("schema", None)
@@ -442,13 +447,10 @@ def _submit_job(
     spec_overrides = kwargs.pop("spec_overrides", None)
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", None)
+    additional_payloads = kwargs.pop("additional_payloads", None)
 
-    # Check for deprecated args
-    if "num_instances" in kwargs:
-        logger.warning(
-            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
-        )
-        target_instances = max(target_instances, kwargs.pop("num_instances"))
+    if additional_payloads:
+        logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")
 
     # Warn if there are unknown kwargs
     if kwargs:
@@ -464,8 +466,7 @@ def _submit_job(
     if min_instances > 1:
         # Validate min_instances against compute pool max_nodes
         pool_info = jb._get_compute_pool_info(session, compute_pool)
-        requested_attributes = query_helper.get_attribute_map(session, {"max_nodes": 3})
-        max_nodes = int(pool_info[requested_attributes["max_nodes"]])
+        max_nodes = int(pool_info["max_nodes"])
         if min_instances > max_nodes:
             raise ValueError(
                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
@@ -480,9 +481,7 @@ def _submit_job(
 
     # Upload payload
     uploaded_payload = payload_utils.JobPayload(
-        source,
-        entrypoint=entrypoint,
-        pip_requirements=pip_requirements,
+        source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
     ).upload(session, stage_path)
 
     # Generate service spec
@@ -502,7 +501,48 @@ def _submit_job(
     if spec_overrides:
         spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
 
-    # Generate SQL command for job submission
+    query_text, params = _generate_submission_query(
+        spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+    )
+    try:
+        _ = query_helper.run_query(session, query_text, params=params)
+    except SnowparkSQLException as e:
+        if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
+            logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
+            spec["spec"].pop("resourceManagement", None)
+            query_text, params = _generate_submission_query(
+                spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+            )
+            _ = query_helper.run_query(session, query_text, params=params)
+        else:
+            raise
+    return get_job(job_id, session=session)
+
+
+def _generate_submission_query(
+    spec: dict[str, Any],
+    external_access_integrations: list[str],
+    query_warehouse: Optional[str],
+    target_instances: int,
+    session: snowpark.Session,
+    compute_pool: str,
+    job_id: str,
+) -> tuple[str, list[Any]]:
+    """
+    Generate the SQL query for job submission.
+
+    Args:
+        spec: The service spec for the job.
+        external_access_integrations: The external access integrations for the job.
+        query_warehouse: The query warehouse for the job.
+        target_instances: The number of instances for the job.
+        session: The Snowpark session to use.
+        compute_pool: The compute pool to use for the job.
+        job_id: The ID of the job.
+
+    Returns:
+        A tuple containing the SQL query text and the parameters for the query.
+    """
     query_template = textwrap.dedent(
         """\
         EXECUTE JOB SERVICE
@@ -526,17 +566,5 @@ def _submit_job(
     if target_instances > 1:
         query.append("REPLICAS = ?")
         params.append(target_instances)
-
-    # Submit job
     query_text = "\n".join(line for line in query if line)
-
-    try:
-        _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
-    except errors.ProgrammingError as e:
-        if "invalid property 'ASYNC'" in str(e):
-            raise RuntimeError(
-                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
-            ) from e
-        raise
-
-    return get_job(job_id, session=session)
+    return query_text, params
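Job submission now builds its SQL in `_generate_submission_query` and retries once without the optional `resourceManagement` section if the account rejects that spec option. A generic sketch of this fallback pattern; `build_query` and `run_query` are hypothetical stand-ins for the helpers used above:

    from typing import Any, Callable

    def submit_with_fallback(
        spec: dict[str, Any],
        build_query: Callable[[dict[str, Any]], tuple[str, list[Any]]],
        run_query: Callable[[str, list[Any]], Any],
    ) -> None:
        query_text, params = build_query(spec)
        try:
            run_query(query_text, params)
        except Exception as e:  # the real code catches SnowparkSQLException
            if "unknown option 'resourceManagement'" in str(e):
                # Control policies not enabled on this account: drop the section and retry once.
                spec["spec"].pop("resourceManagement", None)
                query_text, params = build_query(spec)
                run_query(query_text, params)
            else:
                raise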