snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
  2. snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
  3. snowflake/ml/_internal/telemetry.py +3 -2
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
  5. snowflake/ml/experiment/callback/keras.py +3 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +3 -0
  7. snowflake/ml/experiment/callback/xgboost.py +3 -0
  8. snowflake/ml/experiment/experiment_tracking.py +19 -7
  9. snowflake/ml/feature_store/feature_store.py +236 -61
  10. snowflake/ml/jobs/__init__.py +4 -0
  11. snowflake/ml/jobs/_interop/__init__.py +0 -0
  12. snowflake/ml/jobs/_interop/data_utils.py +124 -0
  13. snowflake/ml/jobs/_interop/dto_schema.py +95 -0
  14. snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
  15. snowflake/ml/jobs/_interop/legacy.py +225 -0
  16. snowflake/ml/jobs/_interop/protocols.py +471 -0
  17. snowflake/ml/jobs/_interop/results.py +51 -0
  18. snowflake/ml/jobs/_interop/utils.py +144 -0
  19. snowflake/ml/jobs/_utils/constants.py +16 -2
  20. snowflake/ml/jobs/_utils/feature_flags.py +37 -5
  21. snowflake/ml/jobs/_utils/payload_utils.py +8 -2
  22. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
  23. snowflake/ml/jobs/_utils/spec_utils.py +2 -1
  24. snowflake/ml/jobs/_utils/stage_utils.py +4 -0
  25. snowflake/ml/jobs/_utils/types.py +15 -0
  26. snowflake/ml/jobs/job.py +186 -40
  27. snowflake/ml/jobs/manager.py +48 -39
  28. snowflake/ml/model/__init__.py +19 -0
  29. snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
  30. snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
  31. snowflake/ml/model/_client/model/model_version_impl.py +168 -18
  32. snowflake/ml/model/_client/ops/model_ops.py +4 -0
  33. snowflake/ml/model/_client/ops/service_ops.py +3 -0
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  36. snowflake/ml/model/_client/sql/model_version.py +3 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
  39. snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
  40. snowflake/ml/model/_packager/model_env/model_env.py +22 -5
  41. snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
  42. snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
  43. snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
  44. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  46. snowflake/ml/model/type_hints.py +16 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
  48. snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
  49. snowflake/ml/version.py +1 -1
  50. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
  51. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
  52. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
  53. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py
@@ -7,6 +7,7 @@ from typing import Any, Callable, Optional, Union, overload
 
 import pandas as pd
 
+from snowflake import snowpark
 from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
@@ -19,7 +20,9 @@ from snowflake.ml.model._client.model import (
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
+from snowflake.ml.model._packager.model_meta import model_meta_schema
 from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
@@ -41,6 +44,7 @@ class ModelVersion(lineage_node.LineageNode):
     _model_name: sql_identifier.SqlIdentifier
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
+    _model_spec: Optional[model_meta_schema.ModelMetadataDict]
 
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -150,6 +154,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._model_name = model_name
         self._version_name = version_name
         self._functions = self._get_functions()
+        self._model_spec = None
        super(cls, cls).__init__(
            self,
            session=model_ops._session,
@@ -437,6 +442,26 @@ class ModelVersion(lineage_node.LineageNode):
         """
         return self._functions
 
+    def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
+        """Fetch and cache the model spec for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            The model spec as a dictionary for this model version.
+        """
+        if self._model_spec is None:
+            self._model_spec = self._model_ops._fetch_model_spec(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_spec
+
     @overload
     def run(
         self,
@@ -531,6 +556,8 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
             )
         else:
+            explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
+
             return self._model_ops.invoke_method(
                 method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
                 method_function_type=target_function_info["target_method_function_type"],
@@ -544,8 +571,20 @@ class ModelVersion(lineage_node.LineageNode):
                 partition_column=partition_column,
                 statement_params=statement_params,
                 is_partitioned=target_function_info["is_partitioned"],
+                explain_case_sensitive=explain_case_sensitive,
             )
 
+    def _determine_explain_case_sensitivity(
+        self,
+        target_function_info: model_manifest_schema.ModelFunctionInfo,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        model_spec = self._get_model_spec(statement_params)
+        method_options = model_spec.get("method_options", {})
+        return model_method_utils.determine_explain_case_sensitive_from_method_options(
+            method_options, target_function_info["name"]
+        )
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
@@ -555,7 +594,8 @@ class ModelVersion(lineage_node.LineageNode):
             "job_spec",
         ],
     )
-    def _run_batch(
+    @snowpark._internal.utils.private_preview(version="1.18.0")
+    def run_batch(
         self,
         *,
         compute_pool: str,
@@ -563,6 +603,68 @@ class ModelVersion(lineage_node.LineageNode):
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
     ) -> jobs.MLJob[Any]:
+        """Execute batch inference on datasets as an SPCS job.
+
+        Args:
+            compute_pool (str): Name of the compute pool to use for building the image containers and batch
+                inference execution.
+            input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
+                The DataFrame should contain all required features for model prediction and passthrough columns.
+            output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
+                the inference results. Specifies the stage location and file handling behavior.
+            job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
+                execution parameters such as compute resources, worker counts, and job naming.
+                If None, default values will be used.
+
+        Returns:
+            jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
+                lifecycle.
+
+        Raises:
+            ValueError: If no warehouse is set in job_spec and no current warehouse is available.
+            RuntimeError: If the input_spec cannot be processed or written to the staging location.
+
+        Example:
+            >>> # Prepare input data - Example 1: From a table
+            >>> input_df = session.table("my_input_table")
+            >>>
+            >>> # Prepare input data - Example 2: From a SQL query
+            >>> input_df = session.sql(
+            ...     "SELECT id, feature_1, feature_2 FROM feature_table WHERE feature_1 > 100"
+            ... )
+            >>>
+            >>> # Prepare input data - Example 3: From Parquet files in a stage
+            >>> input_df = session.read.option("pattern", ".*\\.parquet").parquet(
+            ...     "@my_stage/input_data/"
+            ... ).select("id", "feature_1", "feature_2")
+            >>>
+            >>> # Configure output location
+            >>> output_spec = OutputSpec(
+            ...     stage_location='@My_DB.PUBLIC.MY_STAGE/someth/path/',
+            ...     mode=SaveMode.OVERWRITE
+            ... )
+            >>>
+            >>> # Configure job parameters
+            >>> job_spec = JobSpec(
+            ...     job_name="my_batch_inference",
+            ...     num_workers=4,
+            ...     cpu_requests="2",
+            ...     memory_requests="8Gi"
+            ... )
+            >>>
+            >>> # Run batch inference
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     input_spec=input_df,
+            ...     output_spec=output_spec,
+            ...     job_spec=job_spec
+            ... )
+
+        Note:
+            This method is currently in private preview and requires snowflake-ml-python 1.18.0 or later.
+            The input data is temporarily stored in the output stage location under /_temporary before
+            inference execution.
+        """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
@@ -789,6 +891,51 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )
 
+    def _can_run_on_gpu(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> bool:
+        """Check if the model has GPU runtime support.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Returns:
+            True if the model has a GPU runtime configured, False otherwise.
+        """
+        # Fetch model spec
+        model_spec = self._get_model_spec(statement_params)
+
+        # Check if the runtimes section exists and has a gpu runtime
+        runtimes = model_spec.get("runtimes", {})
+        return "gpu" in runtimes
+
+    def _throw_error_if_gpu_is_not_supported(
+        self,
+        gpu_requests: Optional[Union[str, int]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Raise an error if GPU resources are requested but the model lacks GPU runtime support.
+
+        Args:
+            gpu_requests: The GPU limit for GPU-based inference. Can be an integer, fractional,
+                or string value. CPU is used if None.
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model spec.
+
+        Raises:
+            ValueError: If the model does not have GPU runtime support.
+        """
+        if gpu_requests is not None and not self._can_run_on_gpu(statement_params):
+            raise ValueError(
+                f"GPU resources requested (gpu_requests={gpu_requests}), but the model "
+                f"{self.fully_qualified_model_name} version {self.version_name} does not have GPU runtime support. "
+                "Please ensure the model was logged with GPU runtime configuration or do not provide gpu_requests. "
+                "To log the model with GPU runtime configuration, provide `cuda_version` in the `options` while calling"
+                " the `log_model` function."
+            )
+
     def _check_huggingface_text_generation_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
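The check above means GPU deployments only succeed for models logged with a CUDA runtime. A minimal sketch of the logging side, assuming the public `Registry.log_model` API and illustrative names (`session`, `my_model`, the CUDA version string):

    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    mv = reg.log_model(
        my_model,
        model_name="MY_MODEL",
        version_name="V1",
        # Including cuda_version adds a "gpu" runtime to the model spec,
        # which is what _can_run_on_gpu() looks for.
        options={"cuda_version": "11.8"},
    )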
@@ -803,13 +950,7 @@ class ModelVersion(lineage_node.LineageNode):
             ValueError: If the model is not a HuggingFace text-generation model.
         """
         # Fetch model spec
-        model_spec = self._model_ops._fetch_model_spec(
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            statement_params=statement_params,
-        )
+        model_spec = self._get_model_spec(statement_params)
 
         # Check if model_type is huggingface_pipeline
         model_type = model_spec.get("model_type")
@@ -894,9 +1035,10 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean that enables or disables the inference table.
         """
         ...
 
@@ -952,9 +1094,10 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean that enables or disables the inference table.
         """
         ...
 
@@ -1027,21 +1170,20 @@ class ModelVersion(lineage_node.LineageNode):
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
             experimental_options: Experimental options for the service creation with custom inference engine.
-                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
                 `inference_engine` is the name of the inference engine to use.
                 `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+                `autocapture` is a boolean that enables or disables the inference table.
 
 
         Raises:
-            ValueError: Illegal external access integration arguments.
+            ValueError: Illegal external access integration arguments, or if GPU resources are requested
+                but the model does not have GPU runtime support.
             exceptions.SnowparkSQLException: if service already exists.
 
         Returns:
             If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
-
-        Raises:
-            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -1064,12 +1206,16 @@ class ModelVersion(lineage_node.LineageNode):
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
 
-        # Check if model is HuggingFace text-generation before doing inference engine checks
-        if experimental_options:
-            self._check_huggingface_text_generation_model(statement_params)
+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
 
         inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
 
+        # Check if model is HuggingFace text-generation before doing inference engine checks
+        # Only validate if inference engine is actually specified
+        if inference_engine_args is not None:
+            self._check_huggingface_text_generation_model(statement_params)
+
         # Enrich inference engine args if inference engine is specified
         if inference_engine_args is not None:
             inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
@@ -1077,6 +1223,9 @@ class ModelVersion(lineage_node.LineageNode):
                 gpu_requests,
             )
 
+        # Extract autocapture from experimental_options
+        autocapture = experimental_options.get("autocapture") if experimental_options else None
+
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
 
@@ -1116,6 +1265,7 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
                 progress_status=status,
                 inference_engine_args=inference_engine_args,
+                autocapture=autocapture,
             )
             status.update(label="Model service created successfully", state="complete", expanded=False)
             return result
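Taken together, the hunks above thread the new `autocapture` experimental option from `create_service` down to the deployment spec. A hedged usage sketch (service, pool, and repo names are illustrative; keyword names follow the docstring above):

    result = mv.create_service(
        service_name="MY_DB.PUBLIC.MY_SERVICE",
        service_compute_pool="MY_POOL",
        image_repo="MY_DB.PUBLIC.MY_REPO",
        experimental_options={"autocapture": True},  # enable the inference table
    )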
snowflake/ml/model/_client/ops/model_ops.py
@@ -952,6 +952,7 @@ class ModelOperator:
         partition_column: Optional[sql_identifier.SqlIdentifier] = None,
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -967,6 +968,7 @@ class ModelOperator:
         service_name: sql_identifier.SqlIdentifier,
         strict_input_validation: bool = False,
         statement_params: Optional[dict[str, str]] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         ...
 
@@ -986,6 +988,7 @@ class ModelOperator:
         partition_column: Optional[sql_identifier.SqlIdentifier] = None,
         statement_params: Optional[dict[str, str]] = None,
         is_partitioned: Optional[bool] = None,
+        explain_case_sensitive: bool = False,
     ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
         identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
 
@@ -1068,6 +1071,7 @@ class ModelOperator:
             version_name=version_name,
             statement_params=statement_params,
             is_partitioned=is_partitioned or False,
+            explain_case_sensitive=explain_case_sensitive,
         )
 
         if keep_order:
snowflake/ml/model/_client/ops/service_ops.py
@@ -206,6 +206,8 @@ class ServiceOperator:
         hf_model_args: Optional[HFModelArgs] = None,
         # inference engine model
         inference_engine_args: Optional[InferenceEngineArgs] = None,
+        # inference table
+        autocapture: Optional[bool] = None,
     ) -> Union[str, async_job.AsyncJob]:
 
         # Generate operation ID for this deployment
@@ -261,6 +263,7 @@ class ServiceOperator:
             gpu=gpu_requests,
             num_workers=num_workers,
             max_batch_rows=max_batch_rows,
+            autocapture=autocapture,
         )
         if hf_model_args:
             # hf model
snowflake/ml/model/_client/service/model_deployment_spec.py
@@ -146,6 +146,7 @@ class ModelDeploymentSpec:
         gpu: Optional[Union[str, int]] = None,
         num_workers: Optional[int] = None,
         max_batch_rows: Optional[int] = None,
+        autocapture: Optional[bool] = None,
     ) -> "ModelDeploymentSpec":
         """Add service specification to the deployment spec.
 
@@ -161,6 +162,7 @@ class ModelDeploymentSpec:
             gpu: GPU requirement.
             num_workers: Number of workers.
             max_batch_rows: Maximum batch rows for inference.
+            autocapture: Whether to enable the inference table.
 
         Raises:
             ValueError: If a job spec already exists.
@@ -186,6 +188,7 @@ class ModelDeploymentSpec:
             compute_pool=inference_compute_pool_name.identifier(),
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
+            autocapture=autocapture,
             **self._inference_spec,
         )
         return self
snowflake/ml/model/_client/service/model_deployment_spec_schema.py
@@ -32,6 +32,7 @@ class Service(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
+    autocapture: Optional[bool] = None
     inference_engine_spec: Optional[InferenceEngineSpec] = None
 
 
snowflake/ml/model/_client/sql/model_version.py
@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         partition_column: Optional[sql_identifier.SqlIdentifier],
         statement_params: Optional[dict[str, Any]] = None,
         is_partitioned: bool = True,
+        explain_case_sensitive: bool = False,
     ) -> dataframe.DataFrame:
         with_statements = []
         if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
         cols_to_drop = []
 
         for output_name, output_type, output_col_name in returns:
-            output_identifier = sql_identifier.SqlIdentifier(output_name).identifier()
+            case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
+            output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
             if output_identifier != output_col_name:
                 cols_to_drop.append(output_identifier)
             output_cols.append(F.col(output_identifier).astype(output_type))
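The effect of the new flag: explain output identifiers now follow the case sensitivity chosen at save time instead of always being resolved case-insensitively. An assumed illustration of the internal helper's behavior (not part of this diff; Snowflake folds unquoted identifiers to uppercase, while quoted identifiers preserve case):

    # sql_identifier.SqlIdentifier("output_1").identifier()                       -> OUTPUT_1
    # sql_identifier.SqlIdentifier("output_1", case_sensitive=True).identifier()  -> "output_1"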
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
@@ -87,7 +87,9 @@ class ModelManifest:
                     model_meta_schema.FunctionProperties.PARTITIONED.value, False
                 ),
                 wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
-                options=model_method.get_model_method_options_from_options(options, target_method),
+                options=model_method.get_model_method_options_from_options(
+                    options, target_method, model_meta.model_type
+                ),
             )
 
             self.methods.append(method)
snowflake/ml/model/_model_composer/model_method/model_method.py
@@ -11,6 +11,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (
     constants,
     function_generator,
+    utils,
 )
 from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
 from snowflake.ml.model.volatility import Volatility
@@ -31,12 +32,19 @@ class ModelMethodOptions(TypedDict):
 
 
 def get_model_method_options_from_options(
-    options: type_hints.ModelSaveOption, target_method: str
+    options: type_hints.ModelSaveOption, target_method: str, model_type: Optional[str] = None
 ) -> ModelMethodOptions:
     default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
+    method_option = options.get("method_options", {}).get(target_method, {})
+    case_sensitive = method_option.get("case_sensitive", False)
     if target_method == "explain":
         default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-    method_option = options.get("method_options", {}).get(target_method, {})
+        case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
+            options.get("method_options", {}), target_method
+        )
+    elif model_type == "prophet":
+        # Prophet models always require TABLE_FUNCTION because they need the entire time series context
+        default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
     global_function_type = options.get("function_type", default_function_type)
     function_type = method_option.get("function_type", global_function_type)
     if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
@@ -48,7 +56,7 @@ def get_model_method_options_from_options(
 
     # Only include volatility if explicitly provided in method options
     result: ModelMethodOptions = ModelMethodOptions(
-        case_sensitive=method_option.get("case_sensitive", False),
+        case_sensitive=case_sensitive,
         function_type=function_type,
     )
     if resolved_volatility:
snowflake/ml/model/_model_composer/model_method/utils.py (new file)
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import Any, Mapping, Optional
+
+
+def determine_explain_case_sensitive_from_method_options(
+    method_options: Mapping[str, Optional[Mapping[str, Any]]],
+    target_method: str,
+) -> bool:
+    """Determine explain method case sensitivity from related predict methods.
+
+    Args:
+        method_options: Mapping from method name to its options. Each option may
+            contain ``"case_sensitive"`` to indicate SQL identifier sensitivity.
+        target_method: The target method name being resolved (e.g., an ``explain_*``
+            method).
+
+    Returns:
+        True if the explain method should be treated as case sensitive; otherwise False.
+    """
+    if "explain" not in target_method:
+        return False
+    predict_priority_methods = ["predict_proba", "predict", "predict_log_proba"]
+    for src_method in predict_priority_methods:
+        src_opts = method_options.get(src_method)
+        if src_opts is not None:
+            return bool(src_opts.get("case_sensitive", False))
+    return False
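Because an explain method rarely carries its own `case_sensitive` entry, the helper above inherits the flag from the predict methods in a fixed priority order. A small self-contained check (hypothetical options dict, traced against the implementation above):

    method_options = {
        "predict": {"case_sensitive": True},
        "predict_log_proba": {"case_sensitive": False},
    }
    # "predict_proba" is absent, so "predict" is consulted next and its flag wins:
    assert determine_explain_case_sensitive_from_method_options(method_options, "explain") is True
    # Non-explain target methods always resolve to False on this path:
    assert determine_explain_case_sensitive_from_method_options(method_options, "predict") is False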
snowflake/ml/model/_packager/model_env/model_env.py
@@ -240,14 +240,31 @@ class ModelEnv:
             self._conda_dependencies[channel].remove(spec)
 
     def generate_env_for_cuda(self) -> None:
+
+        # Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
         xgboost_spec = env_utils.find_dep_spec(
-            self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
+            self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
         )
         if xgboost_spec:
-            self.include_if_absent(
-                [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
-                check_local_version=False,
-            )
+            # Only handle explicitly pinned versions. Insert the GPU variant iff the pinned major version is < 3.
+            pinned_major: Optional[int] = None
+            for spec in xgboost_spec.specifier:
+                if spec.operator in ("==", "===", ">", ">="):
+                    try:
+                        pinned_major = version.parse(spec.version).major
+                    except version.InvalidVersion:
+                        pinned_major = None
+                    break
+
+            if pinned_major is not None and pinned_major < 3:
+                xgboost_spec = env_utils.find_dep_spec(
+                    self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
+                )
+                if xgboost_spec:
+                    self.include_if_absent(
+                        [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
+                        check_local_version=False,
+                    )
 
         tf_spec = env_utils.find_dep_spec(
             self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
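The new logic above only swaps in `py-xgboost-gpu` when the requirement pins a concrete lower bound whose major version is below 3. A standalone sketch of the same specifier inspection using the `packaging` library directly (`env_utils.find_dep_spec` is internal, so this takes a plain requirement string instead):

    from typing import Optional

    from packaging import version
    from packaging.requirements import Requirement

    def pinned_major(requirement: str) -> Optional[int]:
        """Return the major version pinned by ==, ===, >, or >=, if any."""
        for spec in Requirement(requirement).specifier:
            if spec.operator in ("==", "===", ">", ">="):
                try:
                    return version.parse(spec.version).major
                except version.InvalidVersion:
                    return None
        return None

    assert pinned_major("xgboost==1.7.6") == 1   # GPU variant would be inserted
    assert pinned_major("xgboost>=3.0.0") == 3   # left as-is
    assert pinned_major("xgboost") is None       # unpinned: left as-is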
snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -305,3 +305,73 @@ def get_default_cuda_version() -> str:
 
         return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
     return model_env.DEFAULT_CUDA_VERSION
+
+
+def normalize_column_name(column_name: str) -> str:
+    """Normalize a column name to be a valid unquoted Snowflake SQL identifier.
+
+    Converts column names with spaces and special characters (e.g., "Christmas Day")
+    into valid lowercase unquoted SQL identifiers (e.g., "christmas_day") that can be used
+    without quotes in SQL queries. This follows Snowflake's unquoted identifier rules:
+    https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+
+    The normalization approach is preferred over quoted identifiers because:
+    - Unquoted identifiers are simpler and more readable
+    - They don't require special handling in SQL contexts
+    - They avoid case-sensitivity complications
+    - The lowercase convention improves readability and follows Python/pandas conventions
+
+    This utility is useful for model handlers that need to ensure output column names
+    from models (e.g., Prophet holiday columns, feature names) are SQL-safe.
+
+    Args:
+        column_name: Original column name (may contain spaces, special characters, etc.)
+
+    Returns:
+        Normalized lowercase column name that is a valid unquoted SQL identifier matching
+        the pattern [a-z_][a-z0-9_$]*
+
+    Examples:
+        >>> normalize_column_name("Christmas Day")
+        'christmas_day'
+        >>> normalize_column_name("New Year's Day")
+        'new_year_s_day'
+        >>> normalize_column_name("__private")
+        '__private'
+        >>> normalize_column_name("2023_data")
+        '_2023_data'
+    """
+    import re
+
+    # Convert to lowercase for readability and consistency
+    normalized = column_name.lower()
+
+    # Replace spaces and special characters with underscores.
+    # Keep only alphanumeric characters, underscores, and dollar signs (valid in unquoted identifiers).
+    normalized = re.sub(r"[^a-z0-9_$]", "_", normalized)
+
+    # Collapse consecutive underscores while preserving leading underscores.
+    # This handles cases like "A  B" → "a__b" → "a_b" while preserving "__name" → "__name".
+    if len(normalized) > 1:
+        # Count and preserve leading underscores, collapse the rest
+        leading_underscores = len(normalized) - len(normalized.lstrip("_"))
+        rest_of_string = normalized[leading_underscores:]
+        rest_collapsed = re.sub(r"_+", "_", rest_of_string)
+        normalized = "_" * leading_underscores + rest_collapsed
+
+    if normalized == "_":
+        return "_"
+
+    normalized = normalized.rstrip("_")
+
+    # Ensure it starts with a letter or underscore (SQL requirement).
+    # Unquoted identifiers must match: [a-z_][a-z0-9_$]*
+    if normalized and normalized[0].isdigit():
+        normalized = "_" + normalized
+
+    # If normalization resulted in an empty string, use a default
+    # (this happens when the input was only underscores, like "___")
+    if not normalized:
+        normalized = "column"
+
+    return normalized
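A few extra checks beyond the docstring examples, exercising the collapse-and-trim rules (asserts traced by hand against the implementation above):

    assert normalize_column_name("Holiday: Christmas Day!") == "holiday_christmas_day"  # punctuation -> "_", then collapsed and trimmed
    assert normalize_column_name("GDP ($ per capita)") == "gdp_$_per_capita"            # "$" is legal in unquoted identifiers
    assert normalize_column_name("___") == "column"                                     # all-underscore input falls back to the default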