snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/feature_store/__init__.py +2 -0
  3. snowflake/ml/feature_store/aggregation.py +367 -0
  4. snowflake/ml/feature_store/feature.py +366 -0
  5. snowflake/ml/feature_store/feature_store.py +234 -20
  6. snowflake/ml/feature_store/feature_view.py +189 -4
  7. snowflake/ml/feature_store/metadata_manager.py +425 -0
  8. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  9. snowflake/ml/jobs/__init__.py +2 -0
  10. snowflake/ml/jobs/_utils/constants.py +1 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  12. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  13. snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
  14. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  15. snowflake/ml/jobs/_utils/types.py +22 -2
  16. snowflake/ml/jobs/job_definition.py +232 -0
  17. snowflake/ml/jobs/manager.py +16 -177
  18. snowflake/ml/model/__init__.py +4 -0
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  20. snowflake/ml/model/_client/model/model_version_impl.py +120 -89
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -26
  22. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  23. snowflake/ml/model/_client/ops/service_ops.py +63 -23
  24. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
  25. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  26. snowflake/ml/model/_client/sql/service.py +25 -54
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
  31. snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
  32. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
  33. snowflake/ml/model/_signatures/utils.py +130 -0
  34. snowflake/ml/model/openai_signatures.py +97 -0
  35. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  36. snowflake/ml/version.py +1 -1
  37. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
  38. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
  39. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
  40. snowflake/ml/experiment/callback/__init__.py +0 -0
  41. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  42. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py
@@ -1,6 +1,4 @@
-import base64
 import enum
-import json
 import pathlib
 import tempfile
 import uuid
@@ -8,7 +6,6 @@ import warnings
 from typing import Any, Callable, Optional, Union, overload
 
 import pandas as pd
-from pydantic import TypeAdapter
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -33,7 +30,10 @@ _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
 _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
 _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
-_UTF8_ENCODING = "utf-8"
+VLLM_SUPPORTED_TASKS = [
+    "text-generation",
+    "image-text-to-text",
+]
 
 
 class ExportMode(enum.Enum):
@@ -649,41 +649,6 @@ class ModelVersion(lineage_node.LineageNode):
             method_options, target_function_info["name"]
         )
 
-    @staticmethod
-    def _encode_column_handling(
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
-    ) -> Optional[str]:
-        """Validate and encode column_handling to a base64 string.
-
-        Args:
-            column_handling: Optional dictionary mapping column names to file encoding options.
-
-        Returns:
-            Base64 encoded JSON string of the column handling options, or None if input is None.
-        """
-        # TODO: validation for column names
-        if column_handling is None:
-            return None
-        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
-        # TODO: throw error if the validate_python function fails
-        validated_input = adapter.validate_python(column_handling)
-        return base64.b64encode(adapter.dump_json(validated_input)).decode(_UTF8_ENCODING)
-
-    @staticmethod
-    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
-        """Encode params dictionary to a base64 string.
-
-        Args:
-            params: Optional dictionary of model inference parameters.
-
-        Returns:
-            Base64 encoded JSON string of the params, or None if input is None.
-        """
-        if params is None:
-            return None
-        # TODO: validation for param names, types
-        return base64.b64encode(json.dumps(params).encode(_UTF8_ENCODING)).decode(_UTF8_ENCODING)
-
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
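The two helpers deleted above existed only to squeeze `params` and `column_handling` through a single string argument: JSON-serialize, then base64-encode. As the `@@ -827,12 +822,14 @@` hunk further down shows, 1.24.0 passes both values to `invoke_batch_job_method` as plain objects, so the encoding step disappears. A minimal stdlib sketch of the round-trip the removed code performed:

import base64
import json

# What _encode_params produced (mirrors the removed helper above).
params = {"temperature": 0.7, "top_k": 50}
encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")

# The consuming job reversed it to recover the original dict.
assert json.loads(base64.b64decode(encoded)) == params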
@@ -696,32 +661,33 @@ class ModelVersion(lineage_node.LineageNode):
     @snowpark._internal.utils.private_preview(version="1.18.0")
     def run_batch(
         self,
+        X: dataframe.DataFrame,
         *,
         compute_pool: str,
-        input_spec: dataframe.DataFrame,
+        input_spec: Optional[batch_inference_specs.InputSpec] = None,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-        params: Optional[dict[str, Any]] = None,
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
+        inference_engine_options: Optional[dict[str, Any]] = None,
     ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.
 
         Args:
             compute_pool (str): Name of the compute pool to use for building the image containers and batch
                 inference execution.
-            input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
+            X (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
                 The DataFrame should contain all required features for model prediction and passthrough columns.
             output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
                 the inference results. Specifies the stage location and file handling behavior.
+            input_spec (Optional[batch_inference_specs.InputSpec]): Optional configuration for input
+                processing including model inference parameters and column handling options.
+                If None, default values will be used for params and column_handling.
             job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                 execution parameters such as compute resources, worker counts, and job naming.
                 If None, default values will be used.
-            params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
-                (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
-                model's inference method. Defaults to None.
-            column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
-                specifying how to handle specific columns during file I/O. Maps column names to their
-                file encoding configuration.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
 
         Returns:
             job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
@@ -729,7 +695,7 @@
 
         Raises:
             ValueError: If warehouse is not set in job_spec and no current warehouse is available.
-            RuntimeError: If the input_spec cannot be processed or written to the staging location.
+            RuntimeError: If the input data cannot be processed or written to the staging location.
 
         Example:
             >>> # Prepare input data - Example 1: From a table
@@ -762,10 +728,24 @@
             >>> # Run batch inference
             >>> job = model_version.run_batch(
             ...     compute_pool="my_compute_pool",
-            ...     input_spec=input_df,
+            ...     X=input_df,
             ...     output_spec=output_spec,
             ...     job_spec=job_spec
             ... )
+            >>>
+            >>> # Run batch inference with InputSpec for additional options
+            >>> from snowflake.ml.model._client.model.batch_inference_specs import InputSpec, FileEncoding
+            >>> input_spec = InputSpec(
+            ...     params={"temperature": 0.7, "top_k": 50},
+            ...     column_handling={"image_col": {"encoding": FileEncoding.BASE64}}
+            ... )
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     X=input_df,
+            ...     output_spec=output_spec,
+            ...     input_spec=input_spec,
+            ...     job_spec=job_spec
+            ... )
 
         Note:
             This method is currently in private preview and requires Snowflake version 1.18.0 or later.
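run_batch now also accepts the same `inference_engine_options` dictionary as service deployment. The docstring above only fixes the two keys, `engine` and `engine_args_override`; the concrete values in this sketch (a vLLM engine identifier and a single override flag) are illustrative assumptions, not values confirmed by this diff:

# Hedged sketch reusing input_df/output_spec from the docstring examples above.
inference_engine_options = {
    "engine": "vllm",  # assumed engine identifier
    "engine_args_override": ["--max-model-len=4096"],  # assumed engine argument
}
job = model_version.run_batch(
    X=input_df,
    compute_pool="my_compute_pool",
    output_spec=output_spec,
    inference_engine_options=inference_engine_options,
)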
@@ -777,12 +757,25 @@
             subproject=_TELEMETRY_SUBPROJECT,
         )
 
-        column_handling_as_string = self._encode_column_handling(column_handling)
-        params_as_string = self._encode_params(params)
+        # Extract params and column_handling from input_spec if provided
+        if input_spec is None:
+            input_spec = batch_inference_specs.InputSpec()
+
+        params = input_spec.params
+        column_handling = input_spec.column_handling
 
         if job_spec is None:
             job_spec = batch_inference_specs.JobSpec()
 
+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(job_spec.gpu_requests, statement_params)
+
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            job_spec.gpu_requests,
+            statement_params,
+        )
+
         warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
         if warehouse is None:
             raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
@@ -796,10 +789,10 @@
         self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
 
         try:
-            input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
+            X.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
         # todo: be specific about the type of errors to provide better error messages.
         except Exception as e:
-            raise RuntimeError(f"Failed to process input_spec: {e}")
+            raise RuntimeError(f"Failed to process input data: {e}")
 
         if job_spec.job_name is None:
             # Same as the MLJob ID generation logic with a different prefix
@@ -807,12 +800,14 @@
         else:
             job_name = job_spec.job_name
 
+        target_function_info = self._get_function_info(function_name=job_spec.function_name)
+
         return self._service_ops.invoke_batch_job_method(
             # model version info
             model_name=self._model_name,
             version_name=self._version_name,
             # job spec
-            function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
+            function_name=target_function_info["target_method"],
             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
             force_rebuild=job_spec.force_rebuild,
             image_repo_name=job_spec.image_repo,
@@ -827,12 +822,14 @@
             # input and output
             input_stage_location=input_stage_location,
             input_file_pattern="*",
-            column_handling=column_handling_as_string,
-            params=params_as_string,
+            column_handling=column_handling,
+            params=params,
+            signature_params=target_function_info["signature"].params,
             output_stage_location=output_stage_location,
             completion_filename="_SUCCESS",
             # misc
             statement_params=statement_params,
+            inference_engine_args=inference_engine_args,
         )
 
     def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
@@ -1048,20 +1045,55 @@ class ModelVersion(lineage_node.LineageNode):
             " the `log_model` function."
         )
 
-    def _check_huggingface_text_generation_model(
+    def _prepare_inference_engine_args(
+        self,
+        inference_engine_options: Optional[dict[str, Any]],
+        gpu_requests: Optional[Union[str, int]],
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Prepare and validate inference engine arguments.
+
+        This method handles the common logic for processing inference engine options:
+        1. Parse inference engine options into InferenceEngineArgs
+        2. Validate that the model is a HuggingFace text-generation model (if inference engine is specified)
+        3. Enrich inference engine args
+
+        Args:
+            inference_engine_options: Optional dictionary containing inference engine configuration.
+            gpu_requests: GPU resource request string (e.g., "4").
+            statement_params: Optional dictionary of statement parameters for SQL commands.
+
+        Returns:
+            Prepared InferenceEngineArgs or None if no inference engine is specified.
+        """
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
+
+        if inference_engine_args is not None:
+            # Validate that model is HuggingFace vLLM supported model and is logged with
+            # OpenAI compatible signature.
+            self._check_huggingface_vllm_supported_model(statement_params)
+            # Enrich with GPU configuration
+            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                inference_engine_args,
+                gpu_requests,
+            )
+
+        return inference_engine_args
+
+    def _check_huggingface_vllm_supported_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
-        """Check if the model is a HuggingFace pipeline with text-generation task
-        and is logged with OPENAI_CHAT_SIGNATURE.
+        """Check if the model is a HuggingFace pipeline with vLLM supported task
+        and is logged with OpenAI compatible signature.
 
         Args:
             statement_params: Optional dictionary of statement parameters to include
                 in the SQL command to fetch model spec.
 
         Raises:
-            ValueError: If the model is not a HuggingFace text-generation model or
-                if the model is not logged with OPENAI_CHAT_SIGNATURE.
+            ValueError: If the model is not a HuggingFace vLLM supported model or
+                if the model is not logged with OpenAI compatible signature.
         """
         # Fetch model spec
         model_spec = self._get_model_spec(statement_params)
@@ -1070,34 +1102,37 @@ class ModelVersion(lineage_node.LineageNode):
         model_type = model_spec.get("model_type")
         if model_type != "huggingface_pipeline":
             raise ValueError(
-                f"Inference engine is only supported for HuggingFace text-generation models. "
+                f"Inference engine is only supported for HuggingFace vLLM supported models. "
                 f"Found model_type: {model_type}"
             )
 
-        # Check if model supports text-generation task
+        # Check if model supports vLLM supported task
         # There should only be one model in the list because we don't support multiple models in a single model spec
         models = model_spec.get("models", {})
-        is_text_generation = False
+        is_vllm_supported_task = False
         found_tasks: list[str] = []
 
-        # As long as the model supports text-generation task, we can use it
+        # As long as the model supports vLLM supported task, we can use it
         for _, model_info in models.items():
             options = model_info.get("options", {})
             task = options.get("task")
             if task:
                 found_tasks.append(str(task))
-                if task == "text-generation":
-                    is_text_generation = True
+                if task in VLLM_SUPPORTED_TASKS:
+                    is_vllm_supported_task = True
                     break
 
-        if not is_text_generation:
+        if not is_vllm_supported_task:
             tasks_str = ", ".join(found_tasks)
             found_tasks_str = (
                 f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
             )
-            raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
+            supported_tasks_str = ", ".join(VLLM_SUPPORTED_TASKS)
+            raise ValueError(
+                f"Inference engine is only supported for vLLM supported tasks. {supported_tasks_str}. {found_tasks_str}"
+            )
 
-        # Check if the model is logged with OPENAI_CHAT_SIGNATURE
+        # Check if the model is logged with OpenAI compatible signature.
         signatures_dict = model_spec.get("signatures", {})
 
         # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
@@ -1105,11 +1140,16 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }
 
-        if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+        if deserialized_signatures not in [
+            openai_signatures.OPENAI_CHAT_SIGNATURE,
+            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
+        ]:
             raise ValueError(
-                "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+                "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )
 
     @overload
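The relaxed check above now accepts either OpenAI-compatible signature constant. A hedged sketch of logging a HuggingFace pipeline so it passes this validation; the Snowpark `session` and the `pipe` object are assumed context, while the `signatures=` argument is exactly what the error message prescribes:

# Assumes an existing Snowpark `session` and a transformers pipeline whose
# task is listed in VLLM_SUPPORTED_TASKS (e.g. "text-generation").
from snowflake.ml.model import openai_signatures
from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.log_model(
    pipe,  # assumed: transformers.pipeline("text-generation", model=...)
    model_name="my_llm",
    signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
)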
@@ -1350,20 +1390,11 @@ class ModelVersion(lineage_node.LineageNode):
         # Validate GPU support if GPU resources are requested
         self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
 
-        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
-
-        # Check if model is HuggingFace text-generation and is logged with
-        # OPENAI_CHAT_SIGNATURE before doing inference engine checks
-        # Only validate if inference engine is actually specified
-        if inference_engine_args is not None:
-            self._check_huggingface_text_generation_model(statement_params)
-
-        # Enrich inference engine args if inference engine is specified
-        if inference_engine_args is not None:
-            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
-                inference_engine_args,
-                gpu_requests,
-            )
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            gpu_requests,
+            statement_params,
+        )
 
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
snowflake/ml/model/_client/ops/model_ops.py
@@ -10,11 +10,10 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
 import yaml
 from typing_extensions import NotRequired
 
-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
-from snowflake.ml.model._client.ops import deployment_step, metadata_ops
+from snowflake.ml.model._client.ops import deployment_step, metadata_ops, param_utils
 from snowflake.ml.model._client.sql import (
     model as model_sql,
     model_version as model_version_sql,
@@ -698,9 +697,6 @@ class ModelOperator:
 
         result: list[ServiceInfo] = []
         is_privatelink_connection = self._is_privatelink_connection()
-        is_autocapture_param_enabled = (
-            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
-        )
 
         for fully_qualified_service_name in fully_qualified_service_names:
             port: Optional[int] = None
@@ -742,10 +738,8 @@
                 inference_endpoint=inference_endpoint,
                 internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
             )
-            if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
-                # Include column only if parameter is enabled and spec exists for service owner caller
-                autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
-                service_info["autocapture_enabled"] = autocapture_enabled
+            autocapture_enabled = self._service_client.is_autocapture_enabled(service_description)
+            service_info["autocapture_enabled"] = autocapture_enabled
 
             result.append(service_info)
 
@@ -1063,23 +1057,7 @@
             col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
             input_args.append(col_name)
 
-        method_parameters: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None
-        if signature.params:
-            # Start with defaults from signature
-            final_params = {}
-            for param_spec in signature.params:
-                if hasattr(param_spec, "default_value"):
-                    final_params[param_spec.name] = param_spec.default_value
-
-            # Override with provided runtime parameters
-            if params:
-                final_params.update(params)
-
-            # Convert to list of tuples with SqlIdentifier for parameter names
-            method_parameters = [
-                (sql_identifier.SqlIdentifier(param_name), param_value)
-                for param_name, param_value in final_params.items()
-            ]
+        method_parameters = param_utils.validate_and_resolve_params(params, signature.params)
 
         returns = []
         for output_feature in signature.outputs:
snowflake/ml/model/_client/ops/param_utils.py (new file)
@@ -0,0 +1,124 @@
+"""Utility functions for model parameter validation and resolution."""
+
+from typing import Any, Optional, Sequence
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._signatures import core
+
+
+def validate_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> None:
+    """Validate user-provided params against signature params.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Raises:
+        SnowflakeMLException: If params are provided but signature has no params,
+            or if unknown params are provided, or if param types are invalid,
+            or if duplicate params are provided with different cases.
+    """
+    # Params provided but signature has no params defined
+    if params and not signature_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Parameters were provided ({sorted(params.keys())}), "
+                "but this method does not accept any parameters."
+            ),
+        )
+
+    if not signature_params or not params:
+        return
+
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+    normalized_names = [name.upper() for name in params]
+    if len(normalized_names) != len(set(normalized_names)):
+        # Find the duplicate params to raise an error
+        param_seen: dict[str, list[str]] = {}
+        for param_name in params:
+            param_seen.setdefault(param_name.upper(), []).append(param_name)
+        duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+                "Parameter names are case-insensitive."
+            ),
+        )
+
+    # Validate user-provided params exist (case-insensitive)
+    invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+    if invalid_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Unknown parameter(s): {sorted(invalid_params)}. "
+                f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+            ),
+        )
+
+    # Validate types for each provided param
+    for param_name, default_value in params.items():
+        param_spec = param_spec_lookup[param_name.upper()]
+        if isinstance(param_spec, core.ParamSpec):
+            core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+def resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Sequence[core.BaseParamSpec],
+) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+    """Resolve final method parameters by applying user-provided params over signature defaults.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation.
+    """
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Start with defaults from signature
+    final_params: dict[str, Any] = {}
+    for param_spec in signature_params:
+        if hasattr(param_spec, "default_value"):
+            final_params[param_spec.name] = param_spec.default_value
+
+    # Override with provided runtime parameters (using signature's original param names)
+    if params:
+        for param_name, override_value in params.items():
+            canonical_name = param_spec_lookup[param_name.upper()].name
+            final_params[canonical_name] = override_value
+
+    return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+def validate_and_resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+    """Validate user-provided params against signature params and return method parameters.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+    """
+    validate_params(params, signature_params)
+
+    if not signature_params:
+        return None
+
+    return resolve_params(params, signature_params)
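The new module centralizes what model_ops.py previously inlined, adding case-insensitive matching and validation. A self-contained sketch of the behavior; `SimpleNamespace` stands in for `core.BaseParamSpec`, whose constructor this diff does not show, and only provides the `.name` and `.default_value` attributes that `resolve_params` actually reads:

from types import SimpleNamespace

from snowflake.ml.model._client.ops import param_utils

# Stand-ins for signature param specs (assumed attribute shape only).
temperature = SimpleNamespace(name="temperature", default_value=1.0)
top_k = SimpleNamespace(name="top_k", default_value=50)

# Overrides match case-insensitively and are re-keyed to the signature's
# canonical names; params without overrides keep their defaults.
resolved = param_utils.validate_and_resolve_params({"TEMPERATURE": 0.2}, [temperature, top_k])
# resolved == [(SqlIdentifier("temperature"), 0.2), (SqlIdentifier("top_k"), 50)]

Unknown names, and duplicates that collide case-insensitively (e.g. {"top_k": 1, "TOP_K": 2}), raise SnowflakeMLException with INVALID_ARGUMENT instead of being forwarded to the job.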