snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/telemetry.py +3 -2
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  5. snowflake/ml/experiment/callback/keras.py +3 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  7. snowflake/ml/experiment/callback/xgboost.py +3 -0
  8. snowflake/ml/experiment/experiment_tracking.py +19 -7
  9. snowflake/ml/feature_store/feature_store.py +236 -61
  10. snowflake/ml/jobs/__init__.py +4 -0
  11. snowflake/ml/jobs/_interop/__init__.py +0 -0
  12. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  13. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  14. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  15. snowflake/ml/jobs/_interop/legacy.py +225 -0
  16. snowflake/ml/jobs/_interop/protocols.py +471 -0
  17. snowflake/ml/jobs/_interop/results.py +51 -0
  18. snowflake/ml/jobs/_interop/utils.py +144 -0
  19. snowflake/ml/jobs/_utils/constants.py +16 -2
  20. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  21. snowflake/ml/jobs/_utils/payload_utils.py +8 -2
  22. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  23. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  24. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  25. snowflake/ml/jobs/_utils/types.py +15 -0
  26. snowflake/ml/jobs/job.py +186 -40
  27. snowflake/ml/jobs/manager.py +48 -39
  28. snowflake/ml/model/__init__.py +19 -0
  29. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  30. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  31. snowflake/ml/model/_client/model/model_version_impl.py +168 -18
  32. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  33. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  36. snowflake/ml/model/_client/sql/model_version.py +3 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
  39. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  40. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  41. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  42. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  43. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  44. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  46. snowflake/ml/model/type_hints.py +16 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  48. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  49. snowflake/ml/version.py +1 -1
  50. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
  51. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
  52. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  53. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -12,12 +12,19 @@ from snowflake import snowpark
12
12
  from snowflake.ml._internal import telemetry
13
13
  from snowflake.ml._internal.utils import identifier
14
14
  from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
15
- from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
15
+ from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
16
+ from snowflake.ml.jobs._utils import (
17
+ constants,
18
+ payload_utils,
19
+ query_helper,
20
+ stage_utils,
21
+ types,
22
+ )
16
23
  from snowflake.snowpark import Row, context as sp_context
17
24
  from snowflake.snowpark.exceptions import SnowparkSQLException
18
25
 
19
26
  _PROJECT = "MLJob"
20
- TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
27
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}
21
28
 
22
29
  T = TypeVar("T")
23
30
 
@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
36
43
  self._session = session or sp_context.get_active_session()
37
44
 
38
45
  self._status: types.JOB_STATUS = "PENDING"
39
- self._result: Optional[interop_utils.ExecutionResult] = None
46
+ self._result: Optional[interop_result.ExecutionResult] = None
47
+
48
+ @cached_property
49
+ def _service_info(self) -> types.ServiceInfo:
50
+ """Get the job's service info."""
51
+ return _resolve_service_info(self.id, self._session)
40
52
 
41
53
  @cached_property
42
54
  def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
44
56
 
45
57
  @cached_property
46
58
  def target_instances(self) -> int:
47
- return _get_target_instances(self._session, self.id)
59
+ return self._service_info.target_instances
48
60
 
49
61
  @cached_property
50
62
  def min_instances(self) -> int:
@@ -69,8 +81,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
69
81
  @cached_property
70
82
  def _compute_pool(self) -> str:
71
83
  """Get the job's compute pool name."""
72
- row = _get_service_info(self._session, self.id)
73
- return cast(str, row["compute_pool"])
84
+ return self._service_info.compute_pool
74
85
 
75
86
  @property
76
87
  def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
82
93
  @property
83
94
  def _container_spec(self) -> dict[str, Any]:
84
95
  """Get the job's main container spec."""
85
- containers = self._service_spec["spec"]["containers"]
96
+ try:
97
+ containers = self._service_spec["spec"]["containers"]
98
+ except SnowparkSQLException as e:
99
+ if e.sql_error_code == 2003:
100
+ # If the job is deleted, the service spec is not available
101
+ return {}
102
+ raise
86
103
  if len(containers) == 1:
87
104
  return cast(dict[str, Any], containers[0])
88
105
  try:
@@ -103,24 +120,30 @@ class MLJob(Generic[T], SerializableSessionMixin):
103
120
  """Get the job's result file location."""
104
121
  result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
105
122
  if result_path_str is None:
106
- raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
107
-
108
- # If result path is relative, it is relative to the stage mount path
109
- result_path = Path(result_path_str)
110
- if not result_path.is_absolute():
111
- return f"{self._stage_path}/{result_path.as_posix()}"
112
-
113
- # If result path is absolute, it is relative to the stage mount path
123
+ raise NotImplementedError(f"Job {self.name} doesn't have a result path configured")
124
+
125
+ return self._transform_path(result_path_str)
126
+
127
+ def _transform_path(self, path_str: str) -> str:
128
+ """Transform a local path within the container to a stage path."""
129
+ path = payload_utils.resolve_path(path_str)
130
+ if isinstance(path, stage_utils.StagePath):
131
+ # Stage paths need no transformation
132
+ return path.as_posix()
133
+ if not path.is_absolute():
134
+ # Assume relative paths are relative to stage mount path
135
+ return f"{self._stage_path}/{path.as_posix()}"
136
+
137
+ # If result path is absolute, rebase it onto the stage mount path
138
+ # TODO: Rather than matching by name, use the longest mount path which matches
114
139
  volume_mounts = self._container_spec["volumeMounts"]
115
140
  stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
116
141
  stage_mount = Path(stage_mount_str)
117
142
  try:
118
- relative_path = result_path.relative_to(stage_mount)
143
+ relative_path = path.relative_to(stage_mount)
119
144
  return f"{self._stage_path}/{relative_path.as_posix()}"
120
145
  except ValueError:
121
- raise ValueError(
122
- f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
123
- )
146
+ raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")
124
147
 
125
148
  @overload
126
149
  def get_logs(
@@ -165,7 +188,14 @@ class MLJob(Generic[T], SerializableSessionMixin):
165
188
  Returns:
166
189
  The job's execution logs.
167
190
  """
168
- logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
191
+ logs = _get_logs(
192
+ self._session,
193
+ self.id,
194
+ limit,
195
+ instance_id,
196
+ self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
197
+ verbose,
198
+ )
169
199
  assert isinstance(logs, str) # mypy
170
200
  if as_list:
171
201
  return logs.splitlines()
@@ -199,8 +229,22 @@ class MLJob(Generic[T], SerializableSessionMixin):
199
229
  Raises:
200
230
  TimeoutError: If the job does not complete within the specified timeout.
201
231
  """
202
- delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS # Start with 100ms delay
203
232
  start_time = time.monotonic()
233
+ try:
234
+ # spcs_wait_for() is a synchronous query, it’s more effective to do polling with exponential
235
+ # backoff. If the job is running for a long time. We want a hybrid option: use spcs_wait_for()
236
+ # for the first 30 seconds, then switch to polling for long running jobs
237
+ min_timeout = (
238
+ int(min(timeout, constants.JOB_SPCS_TIMEOUT_SECONDS))
239
+ if timeout >= 0
240
+ else constants.JOB_SPCS_TIMEOUT_SECONDS
241
+ )
242
+ query_helper.run_query(self._session, f"call {self.id}!spcs_wait_for('DONE', {min_timeout})")
243
+ return self.status
244
+ except SnowparkSQLException:
245
+ # the function is not supported in this environment; fall back to polling
246
+ pass
247
+ delay: float = float(constants.JOB_POLL_INITIAL_DELAY_SECONDS) # Start with 5s delay
204
248
  warning_shown = False
205
249
  while (status := self.status) not in TERMINAL_JOB_STATUSES:
206
250
  elapsed = time.monotonic() - start_time
@@ -218,7 +262,6 @@ class MLJob(Generic[T], SerializableSessionMixin):
218
262
  delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
219
263
  return self.status
220
264
 
221
- @snowpark._internal.utils.private_preview(version="1.8.2")
222
265
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
223
266
  def result(self, timeout: float = -1) -> T:
224
267
  """
@@ -237,13 +280,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
237
280
  if self._result is None:
238
281
  self.wait(timeout)
239
282
  try:
240
- self._result = interop_utils.fetch_result(self._session, self._result_path)
283
+ self._result = interop_utils.load_result(
284
+ self._result_path, session=self._session, path_transform=self._transform_path
285
+ )
241
286
  except Exception as e:
242
- raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
287
+ raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e
243
288
 
244
- if self._result.success:
245
- return cast(T, self._result.result)
246
- raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
289
+ return cast(T, self._result.get_value())
247
290
 
248
291
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
249
292
  def cancel(self) -> None:
@@ -256,22 +299,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
256
299
  self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
257
300
  logger.debug(f"Cancellation requested for job {self.id}")
258
301
  except SnowparkSQLException as e:
259
- raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
302
+ raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e
260
303
 
261
304
 
262
305
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
263
306
  def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
264
307
  """Retrieve job or job instance execution status."""
265
- if instance_id is not None:
266
- # Get specific instance status
267
- rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
268
- for row in rows:
269
- if row["instance_id"] == str(instance_id):
270
- return cast(types.JOB_STATUS, row["status"])
271
- raise ValueError(f"Instance {instance_id} not found in job {job_id}")
272
- else:
273
- row = _get_service_info(session, job_id)
274
- return cast(types.JOB_STATUS, row["status"])
308
+ try:
309
+ if instance_id is not None:
310
+ # Get specific instance status
311
+ rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
312
+ for row in rows:
313
+ if row["instance_id"] == str(instance_id):
314
+ return cast(types.JOB_STATUS, row["status"])
315
+ raise ValueError(f"Instance {instance_id} not found in job {job_id}")
316
+ else:
317
+ row = _get_service_info(session, job_id)
318
+ return cast(types.JOB_STATUS, row["status"])
319
+ except SnowparkSQLException as e:
320
+ if e.sql_error_code == 2003:
321
+ row = _get_service_info_spcs(session, job_id)
322
+ return cast(types.JOB_STATUS, row["STATUS"])
323
+ raise
275
324
 
276
325
 
277
326
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -542,8 +591,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
542
591
 
543
592
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
544
593
  def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
545
- row = _get_service_info(session, job_id)
546
- return int(row["target_instances"])
594
+ try:
595
+ row = _get_service_info(session, job_id)
596
+ return int(row["target_instances"])
597
+ except SnowparkSQLException as e:
598
+ if e.sql_error_code == 2003:
599
+ row = _get_service_info_spcs(session, job_id)
600
+ try:
601
+ params = json.loads(row["PARAMETERS"])
602
+ if isinstance(params, dict):
603
+ return int(params.get("REPLICAS", 1))
604
+ else:
605
+ return 1
606
+ except (json.JSONDecodeError, ValueError):
607
+ return 1
608
+ raise
547
609
 
548
610
 
549
611
  def _get_logs_spcs(
@@ -581,3 +643,87 @@ def _get_logs_spcs(
581
643
  query.append(f" LIMIT {limit};")
582
644
  rows = session.sql("\n".join(query)).collect()
583
645
  return rows
646
+
647
+
648
+ def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
649
+ """
650
+ Retrieve the service info from the SPCS interface.
651
+
652
+ Args:
653
+ session (Session): The Snowpark session to use.
654
+ job_id (str): The job ID.
655
+
656
+ Returns:
657
+ Any: The service info.
658
+
659
+ Raises:
660
+ SnowparkSQLException: If the job does not exist or is too old to retrieve.
661
+ """
662
+ db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
663
+ db = db or session.get_current_database()
664
+ schema = schema or session.get_current_schema()
665
+ rows = query_helper.run_query(
666
+ session,
667
+ """
668
+ select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
669
+ from table(snowflake.spcs.get_job_history())
670
+ where database_name = ? and schema_name = ? and name = ?
671
+ """,
672
+ params=(db, schema, name),
673
+ )
674
+ if rows:
675
+ return rows[0]
676
+ else:
677
+ raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
678
+
679
+
680
+ def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
681
+ try:
682
+ row = _get_service_info(session, id)
683
+ except SnowparkSQLException as e:
684
+ if e.sql_error_code == 2003:
685
+ row = _get_service_info_spcs(session, id)
686
+ else:
687
+ raise
688
+ if not row:
689
+ raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
690
+
691
+ if "compute_pool" in row:
692
+ compute_pool = row["compute_pool"]
693
+ elif "COMPUTE_POOL_NAME" in row:
694
+ compute_pool = row["COMPUTE_POOL_NAME"]
695
+ else:
696
+ raise ValueError(f"compute_pool not found in row: {row}")
697
+
698
+ if "status" in row:
699
+ status = row["status"]
700
+ elif "STATUS" in row:
701
+ status = row["STATUS"]
702
+ else:
703
+ raise ValueError(f"status not found in row: {row}")
704
+ # Normalize target_instances
705
+ target_instances: int
706
+ if "target_instances" in row and row["target_instances"] is not None:
707
+ try:
708
+ target_instances = int(row["target_instances"])
709
+ except (ValueError, TypeError):
710
+ target_instances = 1
711
+ elif "PARAMETERS" in row and row["PARAMETERS"]:
712
+ try:
713
+ params = json.loads(row["PARAMETERS"])
714
+ target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
715
+ except (json.JSONDecodeError, ValueError, TypeError):
716
+ target_instances = 1
717
+ else:
718
+ target_instances = 1
719
+
720
+ database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
721
+ schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
722
+
723
+ return types.ServiceInfo(
724
+ database_name=database_name,
725
+ schema_name=schema_name,
726
+ status=cast(types.JOB_STATUS, status),
727
+ compute_pool=cast(str, compute_pool),
728
+ target_instances=target_instances,
729
+ )
@@ -21,6 +21,7 @@ from snowflake.ml.jobs._utils import (
21
21
  spec_utils,
22
22
  types,
23
23
  )
24
+ from snowflake.snowpark._internal import utils as sp_utils
24
25
  from snowflake.snowpark.context import get_active_session
25
26
  from snowflake.snowpark.exceptions import SnowparkSQLException
26
27
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -179,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
179
180
  _ = job._service_spec
180
181
  return job
181
182
  except SnowparkSQLException as e:
182
- if "does not exist" in e.message:
183
- raise ValueError(f"Job does not exist: {job_id}") from e
183
+ if e.sql_error_code == 2003:
184
+ job = jb.MLJob[Any](job_id, session=session)
185
+ _ = job.status
186
+ return job
184
187
  raise
185
188
 
186
189
 
@@ -446,7 +449,7 @@ def _submit_job(
446
449
  Raises:
447
450
  ValueError: If database or schema value(s) are invalid
448
451
  RuntimeError: If schema is not specified in session context or job submission
449
- snowpark.exceptions.SnowparkSQLException: if failed to upload payload
452
+ SnowparkSQLException: if failed to upload payload
450
453
  """
451
454
  session = _ensure_session(session)
452
455
 
@@ -512,49 +515,44 @@ def _submit_job(
512
515
  uploaded_payload = payload_utils.JobPayload(
513
516
  source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
514
517
  ).upload(session, stage_path)
515
- except snowpark.exceptions.SnowparkSQLException as e:
518
+ except SnowparkSQLException as e:
516
519
  if e.sql_error_code == 90106:
517
520
  raise RuntimeError(
518
521
  "Please specify a schema, either in the session context or as a parameter in the job submission"
519
522
  )
520
523
  raise
521
524
 
522
- # FIXME: Temporary patches, remove this after v1 is deprecated
523
- if target_instances > 1:
524
- default_spec_overrides = {
525
- "spec": {
526
- "endpoints": [
527
- {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
528
- ]
529
- },
530
- }
531
- if spec_overrides:
532
- spec_overrides = spec_utils.merge_patch(
533
- default_spec_overrides, spec_overrides, display_name="spec_overrides"
534
- )
535
- else:
536
- spec_overrides = default_spec_overrides
537
-
538
- if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
525
+ if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
539
526
  # Add default env vars (extracted from spec_utils.generate_service_spec)
540
527
  combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
541
528
 
542
- return _do_submit_job_v2(
543
- session=session,
544
- payload=uploaded_payload,
545
- args=args,
546
- env_vars=combined_env_vars,
547
- spec_overrides=spec_overrides,
548
- compute_pool=compute_pool,
549
- job_id=job_id,
550
- external_access_integrations=external_access_integrations,
551
- query_warehouse=query_warehouse,
552
- target_instances=target_instances,
553
- min_instances=min_instances,
554
- enable_metrics=enable_metrics,
555
- use_async=True,
556
- runtime_environment=runtime_environment,
557
- )
529
+ try:
530
+ return _do_submit_job_v2(
531
+ session=session,
532
+ payload=uploaded_payload,
533
+ args=args,
534
+ env_vars=combined_env_vars,
535
+ spec_overrides=spec_overrides,
536
+ compute_pool=compute_pool,
537
+ job_id=job_id,
538
+ external_access_integrations=external_access_integrations,
539
+ query_warehouse=query_warehouse,
540
+ target_instances=target_instances,
541
+ min_instances=min_instances,
542
+ enable_metrics=enable_metrics,
543
+ use_async=True,
544
+ runtime_environment=runtime_environment,
545
+ )
546
+ except SnowparkSQLException as e:
547
+ if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()): # type: ignore[no-untyped-call]
548
+ raise
549
+ # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
550
+ # stored procedures. This will be fixed in an upcoming release.
551
+ logger.warning(
552
+ "Job submission using V2 failed with error {}. Falling back to V1.".format(
553
+ str(e).split("\n", 1)[0],
554
+ )
555
+ )
558
556
 
559
557
  # Fall back to v1
560
558
  # Generate service spec
@@ -688,7 +686,7 @@ def _do_submit_job_v2(
688
686
  # for the image tag or full image URL, we use that directly
689
687
  if runtime_environment:
690
688
  spec_options["RUNTIME"] = runtime_environment
691
- elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
689
+ elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
692
690
  # when feature flag is enabled, we get the local python version and wrap it in a dict
693
691
  # in system function, we can know whether it is python version or image tag or full image URL through the format
694
692
  spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
@@ -699,10 +697,21 @@ def _do_submit_job_v2(
699
697
  "MIN_INSTANCES": min_instances,
700
698
  "ASYNC": use_async,
701
699
  }
700
+ if payload.payload_name:
701
+ job_options["GENERATE_SUFFIX"] = True
702
702
  job_options = {k: v for k, v in job_options.items() if v is not None}
703
703
 
704
704
  query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
705
- params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
705
+ if job_id:
706
+ database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
707
+ params = [
708
+ job_id
709
+ if payload.payload_name is None
710
+ else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
711
+ compute_pool,
712
+ json.dumps(spec_options),
713
+ json.dumps(job_options),
714
+ ]
706
715
  actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
707
716
 
708
717
  return get_job(actual_job_id, session=session)
@@ -1,3 +1,6 @@
1
+ import sys
2
+ import warnings
3
+
1
4
  from snowflake.ml.model._client.model.batch_inference_specs import (
2
5
  JobSpec,
3
6
  OutputSpec,
@@ -18,3 +21,19 @@ __all__ = [
18
21
  "SaveMode",
19
22
  "Volatility",
20
23
  ]
24
+
25
+ _deprecation_warning_msg_for_3_9 = (
26
+ "Python 3.9 is deprecated in snowflake-ml-python. " "Please upgrade to Python 3.10 or greater."
27
+ )
28
+
29
+ warnings.filterwarnings(
30
+ "once",
31
+ message=_deprecation_warning_msg_for_3_9,
32
+ )
33
+
34
+ if sys.version_info.major == 3 and sys.version_info.minor == 9:
35
+ warnings.warn(
36
+ _deprecation_warning_msg_for_3_9,
37
+ category=DeprecationWarning,
38
+ stacklevel=2,
39
+ )
@@ -19,11 +19,74 @@ class SaveMode(str, Enum):
19
19
 
20
20
 
21
21
  class OutputSpec(BaseModel):
22
+ """Specification for batch inference output.
23
+
24
+ Defines where the inference results should be written and how to handle
25
+ existing files at the output location.
26
+
27
+ Attributes:
28
+ stage_location (str): The stage path where batch inference results will be saved.
29
+ This should be a full path including the stage with @ prefix. For example,
30
+ '@My_DB.PUBLIC.MY_STAGE/some/path/'. A non-existent directory will be created.
31
+ Only Snowflake internal stages are supported at this moment.
32
+ mode (SaveMode): The save mode that determines behavior when files already exist
33
+ at the output location. Defaults to SaveMode.ERROR which raises an error
34
+ if files exist. Can be set to SaveMode.OVERWRITE to replace existing files.
35
+
36
+ Example:
37
+ >>> output_spec = OutputSpec(
38
+ ... stage_location="@My_DB.PUBLIC.MY_STAGE/some/path/",
39
+ ... mode=SaveMode.OVERWRITE
40
+ ... )
41
+ """
42
+
22
43
  stage_location: str
23
44
  mode: SaveMode = SaveMode.ERROR
24
45
 
25
46
 
26
47
  class JobSpec(BaseModel):
48
+ """Specification for batch inference job execution.
49
+
50
+ Defines the compute resources, job settings, and execution parameters
51
+ for running batch inference jobs in Snowflake.
52
+
53
+ Attributes:
54
+ image_repo (Optional[str]): Container image repository for the inference job.
55
+ If not specified, uses the default repository.
56
+ job_name (Optional[str]): Custom name for the batch inference job.
57
+ If not provided, a name will be auto-generated in the form of "BATCH_INFERENCE_<UUID>".
58
+ num_workers (Optional[int]): The number of workers to run the inference service for handling
59
+ requests in parallel within an instance of the service. By default, it is set to 2*vCPU+1
60
+ of the node for CPU based inference and 1 for GPU based inference. For GPU based inference,
61
+ please see best practices before playing with this value.
62
+ function_name (Optional[str]): Name of the specific function to call for inference.
63
+ Required when the model has multiple inference functions.
64
+ force_rebuild (bool): Whether to force rebuilding the container image even if
65
+ it already exists. Defaults to False.
66
+ max_batch_rows (int): Maximum number of rows to process in a single batch.
67
+ Defaults to 1024. Larger values may improve throughput.
68
+ warehouse (Optional[str]): Snowflake warehouse to use for the batch inference job.
69
+ If not specified, uses the session's current warehouse.
70
+ cpu_requests (Optional[str]): The cpu limit for CPU based inference. Can be an integer,
71
+ fractional or string values. If None, we attempt to utilize all the vCPU of the node.
72
+ memory_requests (Optional[str]): The memory limit for inference. Can be an integer
73
+ or a fractional value, but requires a unit (GiB, MiB). If None, we attempt to utilize all
74
+ the memory of the node.
75
+ gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
76
+ string value. Uses CPU if None.
77
+ replicas (Optional[int]): Number of job replicas to run for high availability.
78
+ If not specified, defaults to 1 replica.
79
+
80
+ Example:
81
+ >>> job_spec = JobSpec(
82
+ ... job_name="my_inference_job",
83
+ ... num_workers=4,
84
+ ... cpu_requests="2",
85
+ ... memory_requests="8Gi",
86
+ ... max_batch_rows=2048
87
+ ... )
88
+ """
89
+
27
90
  image_repo: Optional[str] = None
28
91
  job_name: Optional[str] = None
29
92
  num_workers: Optional[int] = None
@@ -6,13 +6,9 @@ from snowflake.ml.model._client.ops import service_ops
6
6
  def _get_inference_engine_args(
7
7
  experimental_options: Optional[dict[str, Any]],
8
8
  ) -> Optional[service_ops.InferenceEngineArgs]:
9
-
10
- if not experimental_options:
9
+ if not experimental_options or "inference_engine" not in experimental_options:
11
10
  return None
12
11
 
13
- if "inference_engine" not in experimental_options:
14
- raise ValueError("inference_engine is required in experimental_options")
15
-
16
12
  return service_ops.InferenceEngineArgs(
17
13
  inference_engine=experimental_options["inference_engine"],
18
14
  inference_engine_args_override=experimental_options.get("inference_engine_args_override"),