snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +120 -89
- snowflake/ml/model/_client/ops/model_ops.py +4 -26
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -23
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +25 -54
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
- snowflake/ml/model/_signatures/utils.py +130 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
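The snowflake/ml/version.py entry is the version bump itself. A quick check of which release is installed after upgrading (a hedged sketch: it assumes the module exposes a VERSION string, the conventional content of that one-line file):

import snowflake.ml.version

print(snowflake.ml.version.VERSION)  # expected: "1.24.0" after the upgrade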

snowflake/ml/model/_client/model/model_version_impl.py

@@ -1,6 +1,4 @@
-import base64
 import enum
-import json
 import pathlib
 import tempfile
 import uuid
@@ -8,7 +6,6 @@ import warnings
 from typing import Any, Callable, Optional, Union, overload
 
 import pandas as pd
-from pydantic import TypeAdapter
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -33,7 +30,10 @@ _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
 _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
 _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
-
+VLLM_SUPPORTED_TASKS = [
+    "text-generation",
+    "image-text-to-text",
+]
 
 
 class ExportMode(enum.Enum):
@@ -649,41 +649,6 @@ class ModelVersion(lineage_node.LineageNode):
                 method_options, target_function_info["name"]
             )
 
-    @staticmethod
-    def _encode_column_handling(
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
-    ) -> Optional[str]:
-        """Validate and encode column_handling to a base64 string.
-
-        Args:
-            column_handling: Optional dictionary mapping column names to file encoding options.
-
-        Returns:
-            Base64 encoded JSON string of the column handling options, or None if input is None.
-        """
-        # TODO: validation for column names
-        if column_handling is None:
-            return None
-        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
-        # TODO: throw error if the validate_python function fails
-        validated_input = adapter.validate_python(column_handling)
-        return base64.b64encode(adapter.dump_json(validated_input)).decode(_UTF8_ENCODING)
-
-    @staticmethod
-    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
-        """Encode params dictionary to a base64 string.
-
-        Args:
-            params: Optional dictionary of model inference parameters.
-
-        Returns:
-            Base64 encoded JSON string of the params, or None if input is None.
-        """
-        if params is None:
-            return None
-        # TODO: validation for param names, types
-        return base64.b64encode(json.dumps(params).encode(_UTF8_ENCODING)).decode(_UTF8_ENCODING)
-
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
@@ -696,32 +661,33 @@ class ModelVersion(lineage_node.LineageNode):
     @snowpark._internal.utils.private_preview(version="1.18.0")
     def run_batch(
         self,
+        X: dataframe.DataFrame,
         *,
         compute_pool: str,
-        input_spec:
+        input_spec: Optional[batch_inference_specs.InputSpec] = None,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-
-        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
+        inference_engine_options: Optional[dict[str, Any]] = None,
    ) -> job.MLJob[Any]:
        """Execute batch inference on datasets as an SPCS job.
 
        Args:
            compute_pool (str): Name of the compute pool to use for building the image containers and batch
                inference execution.
-
+            X (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
                The DataFrame should contain all required features for model prediction and passthrough columns.
            output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
                the inference results. Specifies the stage location and file handling behavior.
+            input_spec (Optional[batch_inference_specs.InputSpec]): Optional configuration for input
+                processing including model inference parameters and column handling options.
+                If None, default values will be used for params and column_handling.
            job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                execution parameters such as compute resources, worker counts, and job naming.
                If None, default values will be used.
-
-
-
-
-                specifying how to handle specific columns during file I/O. Maps column names to their
-                file encoding configuration.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
 
        Returns:
            job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
@@ -729,7 +695,7 @@ class ModelVersion(lineage_node.LineageNode):
 
        Raises:
            ValueError: If warehouse is not set in job_spec and no current warehouse is available.
-            RuntimeError: If the
+            RuntimeError: If the input data cannot be processed or written to the staging location.
 
        Example:
            >>> # Prepare input data - Example 1: From a table
@@ -762,10 +728,24 @@ class ModelVersion(lineage_node.LineageNode):
            >>> # Run batch inference
            >>> job = model_version.run_batch(
            ...     compute_pool="my_compute_pool",
-            ...
+            ...     X=input_df,
            ...     output_spec=output_spec,
            ...     job_spec=job_spec
            ... )
+            >>>
+            >>> # Run batch inference with InputSpec for additional options
+            >>> from snowflake.ml.model._client.model.batch_inference_specs import InputSpec, FileEncoding
+            >>> input_spec = InputSpec(
+            ...     params={"temperature": 0.7, "top_k": 50},
+            ...     column_handling={"image_col": {"encoding": FileEncoding.BASE64}}
+            ... )
+            >>> job = model_version.run_batch(
+            ...     compute_pool="my_compute_pool",
+            ...     X=input_df,
+            ...     output_spec=output_spec,
+            ...     input_spec=input_spec,
+            ...     job_spec=job_spec
+            ... )
 
        Note:
            This method is currently in private preview and requires Snowflake version 1.18.0 or later.
@@ -777,12 +757,25 @@ class ModelVersion(lineage_node.LineageNode):
            subproject=_TELEMETRY_SUBPROJECT,
        )
 
-
-
+        # Extract params and column_handling from input_spec if provided
+        if input_spec is None:
+            input_spec = batch_inference_specs.InputSpec()
+
+        params = input_spec.params
+        column_handling = input_spec.column_handling
 
        if job_spec is None:
            job_spec = batch_inference_specs.JobSpec()
 
+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(job_spec.gpu_requests, statement_params)
+
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            job_spec.gpu_requests,
+            statement_params,
+        )
+
        warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
        if warehouse is None:
            raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
@@ -796,10 +789,10 @@ class ModelVersion(lineage_node.LineageNode):
        self._service_ops._enforce_save_mode(output_spec.mode, output_stage_location)
 
        try:
-
+            X.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
            # todo: be specific about the type of errors to provide better error messages.
        except Exception as e:
-            raise RuntimeError(f"Failed to process
+            raise RuntimeError(f"Failed to process input data: {e}")
 
        if job_spec.job_name is None:
            # Same as the MLJob ID generation logic with a different prefix
@@ -807,12 +800,14 @@ class ModelVersion(lineage_node.LineageNode):
        else:
            job_name = job_spec.job_name
 
+        target_function_info = self._get_function_info(function_name=job_spec.function_name)
+
        return self._service_ops.invoke_batch_job_method(
            # model version info
            model_name=self._model_name,
            version_name=self._version_name,
            # job spec
-            function_name=
+            function_name=target_function_info["target_method"],
            compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
            force_rebuild=job_spec.force_rebuild,
            image_repo_name=job_spec.image_repo,
@@ -827,12 +822,14 @@ class ModelVersion(lineage_node.LineageNode):
            # input and output
            input_stage_location=input_stage_location,
            input_file_pattern="*",
-            column_handling=
-            params=
+            column_handling=column_handling,
+            params=params,
+            signature_params=target_function_info["signature"].params,
            output_stage_location=output_stage_location,
            completion_filename="_SUCCESS",
            # misc
            statement_params=statement_params,
+            inference_engine_args=inference_engine_args,
        )
 
    def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
@@ -1048,20 +1045,55 @@ class ModelVersion(lineage_node.LineageNode):
                " the `log_model` function."
            )
 
-    def
+    def _prepare_inference_engine_args(
+        self,
+        inference_engine_options: Optional[dict[str, Any]],
+        gpu_requests: Optional[Union[str, int]],
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Prepare and validate inference engine arguments.
+
+        This method handles the common logic for processing inference engine options:
+        1. Parse inference engine options into InferenceEngineArgs
+        2. Validate that the model is a HuggingFace text-generation model (if inference engine is specified)
+        3. Enrich inference engine args
+
+        Args:
+            inference_engine_options: Optional dictionary containing inference engine configuration.
+            gpu_requests: GPU resource request string (e.g., "4").
+            statement_params: Optional dictionary of statement parameters for SQL commands.
+
+        Returns:
+            Prepared InferenceEngineArgs or None if no inference engine is specified.
+        """
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
+
+        if inference_engine_args is not None:
+            # Validate that model is HuggingFace vLLM supported model and is logged with
+            # OpenAI compatible signature.
+            self._check_huggingface_vllm_supported_model(statement_params)
+            # Enrich with GPU configuration
+            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                inference_engine_args,
+                gpu_requests,
+            )
+
+        return inference_engine_args
+
+    def _check_huggingface_vllm_supported_model(
        self,
        statement_params: Optional[dict[str, Any]] = None,
    ) -> None:
-        """Check if the model is a HuggingFace pipeline with
-        and is logged with
+        """Check if the model is a HuggingFace pipeline with vLLM supported task
+        and is logged with OpenAI compatible signature.
 
        Args:
            statement_params: Optional dictionary of statement parameters to include
                in the SQL command to fetch model spec.
 
        Raises:
-            ValueError: If the model is not a HuggingFace
-            if the model is not logged with
+            ValueError: If the model is not a HuggingFace vLLM supported model or
+                if the model is not logged with OpenAI compatible signature.
        """
        # Fetch model spec
        model_spec = self._get_model_spec(statement_params)
@@ -1070,34 +1102,37 @@ class ModelVersion(lineage_node.LineageNode):
        model_type = model_spec.get("model_type")
        if model_type != "huggingface_pipeline":
            raise ValueError(
-                f"Inference engine is only supported for HuggingFace
+                f"Inference engine is only supported for HuggingFace vLLM supported models. "
                f"Found model_type: {model_type}"
            )
 
-        # Check if model supports
+        # Check if model supports vLLM supported task
        # There should only be one model in the list because we don't support multiple models in a single model spec
        models = model_spec.get("models", {})
-
+        is_vllm_supported_task = False
        found_tasks: list[str] = []
 
-        # As long as the model supports
+        # As long as the model supports vLLM supported task, we can use it
        for _, model_info in models.items():
            options = model_info.get("options", {})
            task = options.get("task")
            if task:
                found_tasks.append(str(task))
-            if task
-
+            if task in VLLM_SUPPORTED_TASKS:
+                is_vllm_supported_task = True
                break
 
-        if not
+        if not is_vllm_supported_task:
            tasks_str = ", ".join(found_tasks)
            found_tasks_str = (
                f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
            )
-
+            supported_tasks_str = ", ".join(VLLM_SUPPORTED_TASKS)
+            raise ValueError(
+                f"Inference engine is only supported for vLLM supported tasks. {supported_tasks_str}. {found_tasks_str}"
+            )
 
-        # Check if the model is logged with
+        # Check if the model is logged with OpenAI compatible signature.
        signatures_dict = model_spec.get("signatures", {})
 
        # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
@@ -1105,11 +1140,16 @@ class ModelVersion(lineage_node.LineageNode):
            func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
        }
 
-        if deserialized_signatures
+        if deserialized_signatures not in [
+            openai_signatures.OPENAI_CHAT_SIGNATURE,
+            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
+        ]:
            raise ValueError(
-                "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE
+                "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
                f"Found signatures: {signatures_dict}. "
-                "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
            )
 
    @overload
@@ -1350,20 +1390,11 @@ class ModelVersion(lineage_node.LineageNode):
        # Validate GPU support if GPU resources are requested
        self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
 
-        inference_engine_args =
-
-
-
-
-        if inference_engine_args is not None:
-            self._check_huggingface_text_generation_model(statement_params)
-
-        # Enrich inference engine args if inference engine is specified
-        if inference_engine_args is not None:
-            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
-                inference_engine_args,
-                gpu_requests,
-            )
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            gpu_requests,
+            statement_params,
+        )
 
        from snowflake.ml.model import event_handler
        from snowflake.snowpark import exceptions
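Taken together, the model_version_impl.py changes let run_batch opt into a custom inference engine. A hedged sketch of the call follows; input_df, output_spec, and job_spec are the objects built in the docstring example above, and only the engine/engine_args_override keys come from this diff. The "vllm" value and the override flags are illustrative assumptions, not values confirmed here:

# Hedged sketch: the exact engine identifier and argument strings are assumptions.
job = model_version.run_batch(
    X=input_df,
    compute_pool="my_compute_pool",
    output_spec=output_spec,
    job_spec=job_spec,
    inference_engine_options={
        "engine": "vllm",  # assumption: identifier not shown in this diff
        "engine_args_override": ["--max-model-len", "4096"],  # assumption: illustrative args
    },
)

Per _check_huggingface_vllm_supported_model, this path only accepts a huggingface_pipeline model whose task is in VLLM_SUPPORTED_TASKS and which was logged with signatures=openai_signatures.OPENAI_CHAT_SIGNATURE (or the _WITH_CONTENT_FORMAT_STRING variant).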

snowflake/ml/model/_client/ops/model_ops.py

@@ -10,11 +10,10 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
 import yaml
 from typing_extensions import NotRequired
 
-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
 from snowflake.ml.model import model_signature, type_hints
-from snowflake.ml.model._client.ops import deployment_step, metadata_ops
+from snowflake.ml.model._client.ops import deployment_step, metadata_ops, param_utils
 from snowflake.ml.model._client.sql import (
     model as model_sql,
     model_version as model_version_sql,
@@ -698,9 +697,6 @@ class ModelOperator:
 
        result: list[ServiceInfo] = []
        is_privatelink_connection = self._is_privatelink_connection()
-        is_autocapture_param_enabled = (
-            platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
-        )
 
        for fully_qualified_service_name in fully_qualified_service_names:
            port: Optional[int] = None
@@ -742,10 +738,8 @@
                inference_endpoint=inference_endpoint,
                internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
            )
-
-
-            autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
-            service_info["autocapture_enabled"] = autocapture_enabled
+            autocapture_enabled = self._service_client.is_autocapture_enabled(service_description)
+            service_info["autocapture_enabled"] = autocapture_enabled
 
            result.append(service_info)
 
@@ -1063,23 +1057,7 @@
            col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
            input_args.append(col_name)
 
-        method_parameters
-        if signature.params:
-            # Start with defaults from signature
-            final_params = {}
-            for param_spec in signature.params:
-                if hasattr(param_spec, "default_value"):
-                    final_params[param_spec.name] = param_spec.default_value
-
-            # Override with provided runtime parameters
-            if params:
-                final_params.update(params)
-
-            # Convert to list of tuples with SqlIdentifier for parameter names
-            method_parameters = [
-                (sql_identifier.SqlIdentifier(param_name), param_value)
-                for param_name, param_value in final_params.items()
-            ]
+        method_parameters = param_utils.validate_and_resolve_params(params, signature.params)
 
        returns = []
        for output_feature in signature.outputs:

snowflake/ml/model/_client/ops/param_utils.py (new file)

@@ -0,0 +1,124 @@
+"""Utility functions for model parameter validation and resolution."""
+
+from typing import Any, Optional, Sequence
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._signatures import core
+
+
+def validate_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> None:
+    """Validate user-provided params against signature params.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Raises:
+        SnowflakeMLException: If params are provided but signature has no params,
+            or if unknown params are provided, or if param types are invalid,
+            or if duplicate params are provided with different cases.
+    """
+    # Params provided but signature has no params defined
+    if params and not signature_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Parameters were provided ({sorted(params.keys())}), "
+                "but this method does not accept any parameters."
+            ),
+        )
+
+    if not signature_params or not params:
+        return
+
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+    normalized_names = [name.upper() for name in params]
+    if len(normalized_names) != len(set(normalized_names)):
+        # Find the duplicate params to raise an error
+        param_seen: dict[str, list[str]] = {}
+        for param_name in params:
+            param_seen.setdefault(param_name.upper(), []).append(param_name)
+        duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+                "Parameter names are case-insensitive."
+            ),
+        )
+
+    # Validate user-provided params exist (case-insensitive)
+    invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+    if invalid_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Unknown parameter(s): {sorted(invalid_params)}. "
+                f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+            ),
+        )
+
+    # Validate types for each provided param
+    for param_name, default_value in params.items():
+        param_spec = param_spec_lookup[param_name.upper()]
+        if isinstance(param_spec, core.ParamSpec):
+            core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+def resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Sequence[core.BaseParamSpec],
+) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+    """Resolve final method parameters by applying user-provided params over signature defaults.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation.
+    """
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Start with defaults from signature
+    final_params: dict[str, Any] = {}
+    for param_spec in signature_params:
+        if hasattr(param_spec, "default_value"):
+            final_params[param_spec.name] = param_spec.default_value
+
+    # Override with provided runtime parameters (using signature's original param names)
+    if params:
+        for param_name, override_value in params.items():
+            canonical_name = param_spec_lookup[param_name.upper()].name
+            final_params[canonical_name] = override_value
+
+    return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+def validate_and_resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+    """Validate user-provided params against signature params and return method parameters.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+    """
+    validate_params(params, signature_params)
+
+    if not signature_params:
+        return None
+
+    return resolve_params(params, signature_params)
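Since validate_params and resolve_params only read name (and, where present, default_value) from each spec, and skip dtype validation unless the spec is a real core.ParamSpec, their case-insensitive resolve-then-override behavior can be exercised with stand-in objects. A hedged usage sketch, not the intended entry point (real callers pass core.ParamSpec objects from signature.params, as model_ops.py does above):

from types import SimpleNamespace

from snowflake.ml.model._client.ops import param_utils

# Stand-ins for core.ParamSpec; its constructor is not shown in this diff.
specs = [
    SimpleNamespace(name="temperature", default_value=1.0),
    SimpleNamespace(name="top_k", default_value=40),
]

# Lookup is case-insensitive: "TEMPERATURE" overrides the "temperature" default.
resolved = param_utils.validate_and_resolve_params({"TEMPERATURE": 0.2}, specs)
# -> roughly [(SqlIdentifier("temperature"), 0.2), (SqlIdentifier("top_k"), 40)]

# Unknown or case-duplicated names raise SnowflakeMLException(INVALID_ARGUMENT):
# param_utils.validate_and_resolve_params({"temprature": 0.2}, specs)  # typo -> error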