snowflake-ml-python 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/jobs/__init__.py +4 -0
  4. snowflake/ml/jobs/_interop/__init__.py +0 -0
  5. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  6. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  7. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  8. snowflake/ml/jobs/_interop/legacy.py +225 -0
  9. snowflake/ml/jobs/_interop/protocols.py +471 -0
  10. snowflake/ml/jobs/_interop/results.py +51 -0
  11. snowflake/ml/jobs/_interop/utils.py +144 -0
  12. snowflake/ml/jobs/_utils/constants.py +4 -1
  13. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  14. snowflake/ml/jobs/_utils/payload_utils.py +1 -1
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  16. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  17. snowflake/ml/jobs/_utils/types.py +10 -0
  18. snowflake/ml/jobs/job.py +168 -36
  19. snowflake/ml/jobs/manager.py +36 -38
  20. snowflake/ml/model/_client/model/model_version_impl.py +39 -7
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  22. snowflake/ml/model/_client/sql/model_version.py +3 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +7 -2
  24. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  25. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  26. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  27. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  29. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  30. snowflake/ml/version.py +1 -1
  31. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +26 -4
  32. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +35 -27
  33. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
  34. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
  35. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -12,12 +12,19 @@ from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
- from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
+ from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
+ from snowflake.ml.jobs._utils import (
+     constants,
+     payload_utils,
+     query_helper,
+     stage_utils,
+     types,
+ )
  from snowflake.snowpark import Row, context as sp_context
  from snowflake.snowpark.exceptions import SnowparkSQLException

  _PROJECT = "MLJob"
- TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"}
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR", "DELETED"}

  T = TypeVar("T")

@@ -36,7 +43,12 @@ class MLJob(Generic[T], SerializableSessionMixin):
          self._session = session or sp_context.get_active_session()

          self._status: types.JOB_STATUS = "PENDING"
-         self._result: Optional[interop_utils.ExecutionResult] = None
+         self._result: Optional[interop_result.ExecutionResult] = None
+
+     @cached_property
+     def _service_info(self) -> types.ServiceInfo:
+         """Get the job's service info."""
+         return _resolve_service_info(self.id, self._session)

      @cached_property
      def name(self) -> str:
@@ -44,7 +56,7 @@ class MLJob(Generic[T], SerializableSessionMixin):

      @cached_property
      def target_instances(self) -> int:
-         return _get_target_instances(self._session, self.id)
+         return self._service_info.target_instances

      @cached_property
      def min_instances(self) -> int:
@@ -69,8 +81,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
      @cached_property
      def _compute_pool(self) -> str:
          """Get the job's compute pool name."""
-         row = _get_service_info(self._session, self.id)
-         return cast(str, row["compute_pool"])
+         return self._service_info.compute_pool

      @property
      def _service_spec(self) -> dict[str, Any]:
@@ -82,7 +93,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
      @property
      def _container_spec(self) -> dict[str, Any]:
          """Get the job's main container spec."""
-         containers = self._service_spec["spec"]["containers"]
+         try:
+             containers = self._service_spec["spec"]["containers"]
+         except SnowparkSQLException as e:
+             if e.sql_error_code == 2003:
+                 # If the job is deleted, the service spec is not available
+                 return {}
+             raise
          if len(containers) == 1:
              return cast(dict[str, Any], containers[0])
          try:
@@ -105,22 +122,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
          if result_path_str is None:
              raise RuntimeError(f"Job {self.name} doesn't have a result path configured")

-         # If result path is relative, it is relative to the stage mount path
-         result_path = Path(result_path_str)
-         if not result_path.is_absolute():
-             return f"{self._stage_path}/{result_path.as_posix()}"
+         return self._transform_path(result_path_str)

-         # If result path is absolute, it is relative to the stage mount path
+     def _transform_path(self, path_str: str) -> str:
+         """Transform a local path within the container to a stage path."""
+         path = payload_utils.resolve_path(path_str)
+         if isinstance(path, stage_utils.StagePath):
+             # Stage paths need no transformation
+             return path.as_posix()
+         if not path.is_absolute():
+             # Assume relative paths are relative to stage mount path
+             return f"{self._stage_path}/{path.as_posix()}"
+
+         # If result path is absolute, rebase it onto the stage mount path
+         # TODO: Rather than matching by name, use the longest mount path which matches
          volume_mounts = self._container_spec["volumeMounts"]
          stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
          stage_mount = Path(stage_mount_str)
          try:
-             relative_path = result_path.relative_to(stage_mount)
+             relative_path = path.relative_to(stage_mount)
              return f"{self._stage_path}/{relative_path.as_posix()}"
          except ValueError:
-             raise ValueError(
-                 f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
-             )
+             raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}")

      @overload
      def get_logs(
@@ -165,7 +188,14 @@ class MLJob(Generic[T], SerializableSessionMixin):
          Returns:
              The job's execution logs.
          """
-         logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
+         logs = _get_logs(
+             self._session,
+             self.id,
+             limit,
+             instance_id,
+             self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME,
+             verbose,
+         )
          assert isinstance(logs, str)  # mypy
          if as_list:
              return logs.splitlines()
@@ -218,7 +248,6 @@ class MLJob(Generic[T], SerializableSessionMixin):
              delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
          return self.status

-     @snowpark._internal.utils.private_preview(version="1.8.2")
      @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
      def result(self, timeout: float = -1) -> T:
          """
@@ -237,13 +266,13 @@ class MLJob(Generic[T], SerializableSessionMixin):
          if self._result is None:
              self.wait(timeout)
              try:
-                 self._result = interop_utils.fetch_result(self._session, self._result_path)
+                 self._result = interop_utils.load_result(
+                     self._result_path, session=self._session, path_transform=self._transform_path
+                 )
              except Exception as e:
-                 raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
+                 raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e

-         if self._result.success:
-             return cast(T, self._result.result)
-         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
+         return cast(T, self._result.get_value())

      @telemetry.send_api_usage_telemetry(project=_PROJECT)
      def cancel(self) -> None:
@@ -256,22 +285,28 @@ class MLJob(Generic[T], SerializableSessionMixin):
              self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
              logger.debug(f"Cancellation requested for job {self.id}")
          except SnowparkSQLException as e:
-             raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+             raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e


  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
  def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
      """Retrieve job or job instance execution status."""
-     if instance_id is not None:
-         # Get specific instance status
-         rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
-         for row in rows:
-             if row["instance_id"] == str(instance_id):
-                 return cast(types.JOB_STATUS, row["status"])
-         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
-     else:
-         row = _get_service_info(session, job_id)
-         return cast(types.JOB_STATUS, row["status"])
+     try:
+         if instance_id is not None:
+             # Get specific instance status
+             rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,))
+             for row in rows:
+                 if row["instance_id"] == str(instance_id):
+                     return cast(types.JOB_STATUS, row["status"])
+             raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+         else:
+             row = _get_service_info(session, job_id)
+             return cast(types.JOB_STATUS, row["status"])
+     except SnowparkSQLException as e:
+         if e.sql_error_code == 2003:
+             row = _get_service_info_spcs(session, job_id)
+             return cast(types.JOB_STATUS, row["STATUS"])
+         raise


  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
@@ -542,8 +577,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:

  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
  def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
-     row = _get_service_info(session, job_id)
-     return int(row["target_instances"])
+     try:
+         row = _get_service_info(session, job_id)
+         return int(row["target_instances"])
+     except SnowparkSQLException as e:
+         if e.sql_error_code == 2003:
+             row = _get_service_info_spcs(session, job_id)
+             try:
+                 params = json.loads(row["PARAMETERS"])
+                 if isinstance(params, dict):
+                     return int(params.get("REPLICAS", 1))
+                 else:
+                     return 1
+             except (json.JSONDecodeError, ValueError):
+                 return 1
+         raise


  def _get_logs_spcs(
@@ -581,3 +629,87 @@ def _get_logs_spcs(
      query.append(f" LIMIT {limit};")
      rows = session.sql("\n".join(query)).collect()
      return rows
+
+
+ def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any:
+     """
+     Retrieve the service info from the SPCS interface.
+
+     Args:
+         session (Session): The Snowpark session to use.
+         job_id (str): The job ID.
+
+     Returns:
+         Any: The service info.
+
+     Raises:
+         SnowparkSQLException: If the job does not exist or is too old to retrieve.
+     """
+     db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+     db = db or session.get_current_database()
+     schema = schema or session.get_current_schema()
+     rows = query_helper.run_query(
+         session,
+         """
+         select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS
+         from table(snowflake.spcs.get_job_history())
+         where database_name = ? and schema_name = ? and name = ?
+         """,
+         params=(db, schema, name),
+     )
+     if rows:
+         return rows[0]
+     else:
+         raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003)
+
+
+ def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo:
+     try:
+         row = _get_service_info(session, id)
+     except SnowparkSQLException as e:
+         if e.sql_error_code == 2003:
+             row = _get_service_info_spcs(session, id)
+         else:
+             raise
+     if not row:
+         raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003)
+
+     if "compute_pool" in row:
+         compute_pool = row["compute_pool"]
+     elif "COMPUTE_POOL_NAME" in row:
+         compute_pool = row["COMPUTE_POOL_NAME"]
+     else:
+         raise ValueError(f"compute_pool not found in row: {row}")
+
+     if "status" in row:
+         status = row["status"]
+     elif "STATUS" in row:
+         status = row["STATUS"]
+     else:
+         raise ValueError(f"status not found in row: {row}")
+     # Normalize target_instances
+     target_instances: int
+     if "target_instances" in row and row["target_instances"] is not None:
+         try:
+             target_instances = int(row["target_instances"])
+         except (ValueError, TypeError):
+             target_instances = 1
+     elif "PARAMETERS" in row and row["PARAMETERS"]:
+         try:
+             params = json.loads(row["PARAMETERS"])
+             target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1
+         except (json.JSONDecodeError, ValueError, TypeError):
+             target_instances = 1
+     else:
+         target_instances = 1
+
+     database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"]
+     schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"]
+
+     return types.ServiceInfo(
+         database_name=database_name,
+         schema_name=schema_name,
+         status=cast(types.JOB_STATUS, status),
+         compute_pool=cast(str, compute_pool),
+         target_instances=target_instances,
+     )
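
Taken together, the job.py changes above make job metadata resilient to service deletion: lookups that previously required a live service now catch SQL error code 2003 ("does not exist") and fall back to the snowflake.spcs.get_job_history() table, and "DELETED" joins the terminal statuses. A minimal usage sketch of how this surfaces through the public API; the job identifier below is hypothetical:

from snowflake.ml import jobs

# Job whose backing SPCS service has already been cleaned up.
job = jobs.get_job("MY_DB.MY_SCHEMA.MLJOB_EXAMPLE")  # hypothetical identifier
print(job.status)            # may now report "DELETED", the new terminal status
print(job.target_instances)  # recovered from job history PARAMETERS["REPLICAS"]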

snowflake/ml/jobs/manager.py CHANGED
@@ -21,6 +21,7 @@ from snowflake.ml.jobs._utils import (
      spec_utils,
      types,
  )
+ from snowflake.snowpark._internal import utils as sp_utils
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -179,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
          _ = job._service_spec
          return job
      except SnowparkSQLException as e:
-         if "does not exist" in e.message:
-             raise ValueError(f"Job does not exist: {job_id}") from e
+         if e.sql_error_code == 2003:
+             job = jb.MLJob[Any](job_id, session=session)
+             _ = job.status
+             return job
          raise


@@ -446,7 +449,7 @@ def _submit_job(
      Raises:
          ValueError: If database or schema value(s) are invalid
          RuntimeError: If schema is not specified in session context or job submission
-         snowpark.exceptions.SnowparkSQLException: if failed to upload payload
+         SnowparkSQLException: if failed to upload payload
      """
      session = _ensure_session(session)

@@ -512,49 +515,44 @@ def _submit_job(
          uploaded_payload = payload_utils.JobPayload(
              source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
          ).upload(session, stage_path)
-     except snowpark.exceptions.SnowparkSQLException as e:
+     except SnowparkSQLException as e:
          if e.sql_error_code == 90106:
              raise RuntimeError(
                  "Please specify a schema, either in the session context or as a parameter in the job submission"
              )
          raise

-     # FIXME: Temporary patches, remove this after v1 is deprecated
-     if target_instances > 1:
-         default_spec_overrides = {
-             "spec": {
-                 "endpoints": [
-                     {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-                 ]
-             },
-         }
-         if spec_overrides:
-             spec_overrides = spec_utils.merge_patch(
-                 default_spec_overrides, spec_overrides, display_name="spec_overrides"
-             )
-         else:
-             spec_overrides = default_spec_overrides
-
-     if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
+     if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True):
          # Add default env vars (extracted from spec_utils.generate_service_spec)
          combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}

-         return _do_submit_job_v2(
-             session=session,
-             payload=uploaded_payload,
-             args=args,
-             env_vars=combined_env_vars,
-             spec_overrides=spec_overrides,
-             compute_pool=compute_pool,
-             job_id=job_id,
-             external_access_integrations=external_access_integrations,
-             query_warehouse=query_warehouse,
-             target_instances=target_instances,
-             min_instances=min_instances,
-             enable_metrics=enable_metrics,
-             use_async=True,
-             runtime_environment=runtime_environment,
-         )
+         try:
+             return _do_submit_job_v2(
+                 session=session,
+                 payload=uploaded_payload,
+                 args=args,
+                 env_vars=combined_env_vars,
+                 spec_overrides=spec_overrides,
+                 compute_pool=compute_pool,
+                 job_id=job_id,
+                 external_access_integrations=external_access_integrations,
+                 query_warehouse=query_warehouse,
+                 target_instances=target_instances,
+                 min_instances=min_instances,
+                 enable_metrics=enable_metrics,
+                 use_async=True,
+                 runtime_environment=runtime_environment,
+             )
+         except SnowparkSQLException as e:
+             if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
+                 raise
+             # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
+             # stored procedures. This will be fixed in an upcoming release.
+             logger.warning(
+                 "Job submission using V2 failed with error {}. Falling back to V1.".format(
+                     str(e).split("\n", 1)[0],
+                 )
+             )

      # Fall back to v1
      # Generate service spec
@@ -688,7 +686,7 @@ def _do_submit_job_v2(
      # for the image tag or full image URL, we use that directly
      if runtime_environment:
          spec_options["RUNTIME"] = runtime_environment
-     elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled():
+     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
          # when feature flag is enabled, we get the local python version and wrap it in a dict
          # in system function, we can know whether it is python version or image tag or full image URL through the format
          spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
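
Two behavioral notes on the manager.py hunks above: get_job now returns a usable MLJob handle for deleted-but-recorded jobs (error 2003) instead of raising ValueError, and V2 submission is on by default (is_enabled(default=True)) with a narrow fallback to V1 only for error 90237 inside owner's-rights stored procedures. A hedged sketch of that guard pattern in isolation; the callables are placeholders, not the library's actual signatures:

from snowflake.snowpark.exceptions import SnowparkSQLException

def submit_with_fallback(submit_v2, submit_v1):
    """Try the V2 path; fall back to V1 only on the one known-benign error."""
    try:
        return submit_v2()
    except SnowparkSQLException as e:
        if e.sql_error_code != 90237:  # anything else is a real failure
            raise
        return submit_v1()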

snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -19,7 +19,9 @@ from snowflake.ml.model._client.model import (
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+ from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
+ from snowflake.ml.model._packager.model_meta import model_meta_schema
  from snowflake.snowpark import Session, async_job, dataframe

  _TELEMETRY_PROJECT = "MLOps"
@@ -41,6 +43,7 @@ class ModelVersion(lineage_node.LineageNode):
      _model_name: sql_identifier.SqlIdentifier
      _version_name: sql_identifier.SqlIdentifier
      _functions: list[model_manifest_schema.ModelFunctionInfo]
+     _model_spec: Optional[model_meta_schema.ModelMetadataDict]

      def __init__(self) -> None:
          raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -150,6 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
          self._model_name = model_name
          self._version_name = version_name
          self._functions = self._get_functions()
+         self._model_spec = None
          super(cls, cls).__init__(
              self,
              session=model_ops._session,
@@ -437,6 +441,26 @@ class ModelVersion(lineage_node.LineageNode):
          """
          return self._functions

+     def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
+         """Fetch and cache the model spec for this model version.
+
+         Args:
+             statement_params: Optional dictionary of statement parameters to include
+                 in the SQL command to fetch the model spec.
+
+         Returns:
+             The model spec as a dictionary for this model version.
+         """
+         if self._model_spec is None:
+             self._model_spec = self._model_ops._fetch_model_spec(
+                 database_name=None,
+                 schema_name=None,
+                 model_name=self._model_name,
+                 version_name=self._version_name,
+                 statement_params=statement_params,
+             )
+         return self._model_spec
+
      @overload
      def run(
          self,
@@ -531,6 +555,8 @@ class ModelVersion(lineage_node.LineageNode):
                  statement_params=statement_params,
              )
          else:
+             explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
+
              return self._model_ops.invoke_method(
                  method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                  method_function_type=target_function_info["target_method_function_type"],
@@ -544,8 +570,20 @@ class ModelVersion(lineage_node.LineageNode):
                  partition_column=partition_column,
                  statement_params=statement_params,
                  is_partitioned=target_function_info["is_partitioned"],
+                 explain_case_sensitive=explain_case_sensitive,
              )

+     def _determine_explain_case_sensitivity(
+         self,
+         target_function_info: model_manifest_schema.ModelFunctionInfo,
+         statement_params: Optional[dict[str, Any]] = None,
+     ) -> bool:
+         model_spec = self._get_model_spec(statement_params)
+         method_options = model_spec.get("method_options", {})
+         return model_method_utils.determine_explain_case_sensitive_from_method_options(
+             method_options, target_function_info["name"]
+         )
+
      @telemetry.send_api_usage_telemetry(
          project=_TELEMETRY_PROJECT,
          subproject=_TELEMETRY_SUBPROJECT,
@@ -803,13 +841,7 @@ class ModelVersion(lineage_node.LineageNode):
          ValueError: If the model is not a HuggingFace text-generation model.
          """
          # Fetch model spec
-         model_spec = self._model_ops._fetch_model_spec(
-             database_name=None,
-             schema_name=None,
-             model_name=self._model_name,
-             version_name=self._version_name,
-             statement_params=statement_params,
-         )
+         model_spec = self._get_model_spec(statement_params)

          # Check if model_type is huggingface_pipeline
          model_type = model_spec.get("model_type")
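
The _get_model_spec cache above means the model spec is fetched at most once per ModelVersion and shared by the HuggingFace check and the new explain case-sensitivity lookup. A hedged sketch of the save-time option that feeds this path, assuming the registry log_model API; reg, clf, and test_df are placeholders:

# method_options recorded at log time are what _get_model_spec() later reads back.
mv = reg.log_model(
    clf,
    model_name="MY_MODEL",
    version_name="V1",
    options={"method_options": {"predict": {"case_sensitive": True}}},
)
# explain inherits predict's case sensitivity, so its output columns are
# resolved as quoted, case-sensitive SQL identifiers.
mv.run(test_df, function_name="explain")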

snowflake/ml/model/_client/ops/model_ops.py CHANGED
@@ -952,6 +952,7 @@ class ModelOperator:
          partition_column: Optional[sql_identifier.SqlIdentifier] = None,
          statement_params: Optional[dict[str, str]] = None,
          is_partitioned: Optional[bool] = None,
+         explain_case_sensitive: bool = False,
      ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
          ...

@@ -967,6 +968,7 @@ class ModelOperator:
          service_name: sql_identifier.SqlIdentifier,
          strict_input_validation: bool = False,
          statement_params: Optional[dict[str, str]] = None,
+         explain_case_sensitive: bool = False,
      ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
          ...

@@ -986,6 +988,7 @@ class ModelOperator:
          partition_column: Optional[sql_identifier.SqlIdentifier] = None,
          statement_params: Optional[dict[str, str]] = None,
          is_partitioned: Optional[bool] = None,
+         explain_case_sensitive: bool = False,
      ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
          identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED

@@ -1068,6 +1071,7 @@ class ModelOperator:
              version_name=version_name,
              statement_params=statement_params,
              is_partitioned=is_partitioned or False,
+             explain_case_sensitive=explain_case_sensitive,
          )

          if keep_order:

snowflake/ml/model/_client/sql/model_version.py CHANGED
@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
          partition_column: Optional[sql_identifier.SqlIdentifier],
          statement_params: Optional[dict[str, Any]] = None,
          is_partitioned: bool = True,
+         explain_case_sensitive: bool = False,
      ) -> dataframe.DataFrame:
          with_statements = []
          if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
              cols_to_drop = []

          for output_name, output_type, output_col_name in returns:
-             output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
+             case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
+             output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
              if output_identifier != output_col_name:
                  cols_to_drop.append(output_identifier)
              output_cols.append(F.col(output_identifier).astype(output_type))

snowflake/ml/model/_model_composer/model_method/model_method.py CHANGED
@@ -11,6 +11,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
  from snowflake.ml.model._model_composer.model_method import (
      constants,
      function_generator,
+     utils,
  )
  from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
  from snowflake.ml.model.volatility import Volatility
@@ -34,9 +35,13 @@ def get_model_method_options_from_options(
      options: type_hints.ModelSaveOption, target_method: str
  ) -> ModelMethodOptions:
      default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
+     method_option = options.get("method_options", {}).get(target_method, {})
+     case_sensitive = method_option.get("case_sensitive", False)
      if target_method == "explain":
          default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-     method_option = options.get("method_options", {}).get(target_method, {})
+         case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
+             options.get("method_options", {}), target_method
+         )
      global_function_type = options.get("function_type", default_function_type)
      function_type = method_option.get("function_type", global_function_type)
      if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
@@ -48,7 +53,7 @@ def get_model_method_options_from_options(

      # Only include volatility if explicitly provided in method options
      result: ModelMethodOptions = ModelMethodOptions(
-         case_sensitive=method_option.get("case_sensitive", False),
+         case_sensitive=case_sensitive,
          function_type=function_type,
      )
      if resolved_volatility:

snowflake/ml/model/_model_composer/model_method/utils.py ADDED
@@ -0,0 +1,28 @@
+ from __future__ import annotations
+
+ from typing import Any, Mapping, Optional
+
+
+ def determine_explain_case_sensitive_from_method_options(
+     method_options: Mapping[str, Optional[Mapping[str, Any]]],
+     target_method: str,
+ ) -> bool:
+     """Determine explain method case sensitivity from related predict methods.
+
+     Args:
+         method_options: Mapping from method name to its options. Each option may
+             contain ``"case_sensitive"`` to indicate SQL identifier sensitivity.
+         target_method: The target method name being resolved (e.g., an ``explain_*``
+             method).
+
+     Returns:
+         True if the explain method should be treated as case sensitive; otherwise False.
+     """
+     if "explain" not in target_method:
+         return False
+     predict_priority_methods = ["predict_proba", "predict", "predict_log_proba"]
+     for src_method in predict_priority_methods:
+         src_opts = method_options.get(src_method)
+         if src_opts is not None:
+             return bool(src_opts.get("case_sensitive", False))
+     return False
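
A minimal usage sketch of the new helper, using the method_options shape shown above (a mapping from method name to an optional options dict):

method_options = {
    "predict_proba": None,                 # present, but no options recorded
    "predict": {"case_sensitive": True},
}

# explain inherits from the first predict-family method that has options
# (priority order: predict_proba, predict, predict_log_proba).
assert determine_explain_case_sensitive_from_method_options(method_options, "explain") is True
# Non-explain targets always resolve to False.
assert determine_explain_case_sensitive_from_method_options(method_options, "predict") is False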

snowflake/ml/model/_packager/model_env/model_env.py CHANGED
@@ -240,14 +240,31 @@ class ModelEnv:
              self._conda_dependencies[channel].remove(spec)

      def generate_env_for_cuda(self) -> None:
+
+         # Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
          xgboost_spec = env_utils.find_dep_spec(
-             self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
+             self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
          )
          if xgboost_spec:
-             self.include_if_absent(
-                 [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
-                 check_local_version=False,
-             )
+             # Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3.
+             pinned_major: Optional[int] = None
+             for spec in xgboost_spec.specifier:
+                 if spec.operator in ("==", "===", ">", ">="):
+                     try:
+                         pinned_major = version.parse(spec.version).major
+                     except version.InvalidVersion:
+                         pinned_major = None
+                     break
+
+             if pinned_major is not None and pinned_major < 3:
+                 xgboost_spec = env_utils.find_dep_spec(
+                     self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
+                 )
+                 if xgboost_spec:
+                     self.include_if_absent(
+                         [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
+                         check_local_version=False,
+                     )

          tf_spec = env_utils.find_dep_spec(
              self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
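
The ModelEnv change above narrows the CUDA handling: py-xgboost-gpu is swapped in only when the xgboost dependency carries an explicit ==/===/>/>= pin whose major version is below 3; unpinned or 3.x dependencies are left untouched. A self-contained sketch of the pin check, assuming the packaging library that the version.parse/version.InvalidVersion calls above point to; pinned_major_below_3 is a hypothetical helper name:

from packaging import version
from packaging.requirements import Requirement

def pinned_major_below_3(req_str: str) -> bool:
    """Return True only for an explicit ==/===/>/>= pin with major version < 3."""
    for spec in Requirement(req_str).specifier:
        if spec.operator in ("==", "===", ">", ">="):
            try:
                return version.parse(spec.version).major < 3
            except version.InvalidVersion:
                return False
    return False  # unpinned or upper-bound-only: leave the dependency alone

assert pinned_major_below_3("xgboost==1.7.6")
assert not pinned_major_below_3("xgboost>=3.0.0")
assert not pinned_major_below_3("xgboost")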