snowflake-ml-python 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff shows the publicly released contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +42 -14
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +3 -3
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +166 -10
- snowflake/ml/model/_client/ops/service_ops.py +63 -28
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +31 -27
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
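Taken together, the changes below introduce an experimental path for serving HuggingFace text-generation models through a custom inference engine (vLLM) and make the inference image repository optional. A minimal usage sketch, under stated assumptions: `session` is an existing Snowpark session, the model, version, pool, and argument values are illustrative, and the `InferenceEngine` import path follows the new `snowflake/ml/model/inference_engine.py` module listed above:

```python
from snowflake.ml.registry import Registry
from snowflake.ml.model.inference_engine import InferenceEngine  # new in 1.11.0 (assumed import path)

reg = Registry(session=session)  # `session`: an assumed, already-created Snowpark session
mv = reg.get_model("MY_TEXT_GEN_MODEL").version("V1")  # must be a HuggingFace text-generation model

mv.create_service(
    service_name="TEXT_GEN_SERVICE",
    service_compute_pool="MY_GPU_POOL",
    gpu_requests=2,  # surfaced to the engine as --tensor-parallel-size=2
    experimental_options={
        "inference_engine": InferenceEngine.VLLM,
        # Optional pass-through arguments; each must carry a "--" prefix.
        "inference_engine_args_override": ["--max-model-len=8192"],
    },
)
```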
snowflake/ml/model/_client/model/model_version_impl.py

```diff
@@ -707,6 +707,128 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )
 
+    def _get_inference_engine_args(
+        self, experimental_options: Optional[dict[str, Any]]
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+
+        if not experimental_options:
+            return None
+
+        if "inference_engine" not in experimental_options:
+            raise ValueError("inference_engine is required in experimental_options")
+
+        return service_ops.InferenceEngineArgs(
+            inference_engine=experimental_options["inference_engine"],
+            inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+        )
+
+    def _enrich_inference_engine_args(
+        self,
+        inference_engine_args: service_ops.InferenceEngineArgs,
+        gpu_requests: Optional[Union[str, int]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Enrich inference engine args with model path and tensor parallelism settings.
+
+        Args:
+            inference_engine_args: The original inference engine args
+            gpu_requests: The number of GPUs requested
+
+        Returns:
+            Enriched inference engine args
+
+        Raises:
+            ValueError: Invalid gpu_requests
+        """
+        if inference_engine_args.inference_engine_args_override is None:
+            inference_engine_args.inference_engine_args_override = []
+
+        # Get model stage path and strip off "snow://" prefix
+        model_stage_path = self._model_ops.get_model_version_stage_path(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+        )
+
+        # Strip "snow://" prefix
+        if model_stage_path.startswith("snow://"):
+            model_stage_path = model_stage_path.replace("snow://", "", 1)
+
+        # Always overwrite the model key by appending
+        inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}")
+
+        gpu_count = None
+
+        # Set tensor-parallelism if gpu_requests is specified
+        if gpu_requests is not None:
+            # assert gpu_requests is a string or an integer before casting to int
+            if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
+                try:
+                    gpu_count = int(gpu_requests)
+                except ValueError:
+                    raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        if gpu_count is not None:
+            if gpu_count > 0:
+                inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+            else:
+                raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        return inference_engine_args
+
+    def _check_huggingface_text_generation_model(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Check if the model is a HuggingFace pipeline with text-generation task.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch model spec.
+
+        Raises:
+            ValueError: If the model is not a HuggingFace text-generation model.
+        """
+        # Fetch model spec
+        model_spec = self._model_ops._fetch_model_spec(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+            statement_params=statement_params,
+        )
+
+        # Check if model_type is huggingface_pipeline
+        model_type = model_spec.get("model_type")
+        if model_type != "huggingface_pipeline":
+            raise ValueError(
+                f"Inference engine is only supported for HuggingFace text-generation models. "
+                f"Found model_type: {model_type}"
+            )
+
+        # Check if model supports text-generation task
+        # There should only be one model in the list because we don't support multiple models in a single model spec
+        models = model_spec.get("models", {})
+        is_text_generation = False
+        found_tasks: list[str] = []
+
+        # As long as the model supports text-generation task, we can use it
+        for _, model_info in models.items():
+            options = model_info.get("options", {})
+            task = options.get("task")
+            if task:
+                found_tasks.append(str(task))
+                if task == "text-generation":
+                    is_text_generation = True
+                    break
+
+        if not is_text_generation:
+            tasks_str = ", ".join(found_tasks)
+            found_tasks_str = (
+                f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
+            )
+            raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
+
     @overload
     def create_service(
         self,
```
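For reference, a condensed, self-contained sketch of the enrichment logic above (an illustrative re-implementation, not the library code; it skips the type guard on `gpu_requests`):

```python
def enrich(args: list, model_stage_path: str, gpu_requests) -> list:
    """Mirror of _enrich_inference_engine_args: append --model, then --tensor-parallel-size."""
    args = list(args)
    args.append(f"--model={model_stage_path.removeprefix('snow://')}")
    if gpu_requests is not None:
        gpu_count = int(gpu_requests)  # raises ValueError on non-numeric input, as above
        if gpu_count <= 0:
            raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
        args.append(f"--tensor-parallel-size={gpu_count}")
    return args

print(enrich([], "snow://some/model/stage/path", "2"))
# ['--model=some/model/stage/path', '--tensor-parallel-size=2']
```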
snowflake/ml/model/_client/model/model_version_impl.py (continued)

```diff
@@ -714,7 +836,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -725,6 +847,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -735,7 +858,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -756,6 +880,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -766,7 +894,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -777,6 +905,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -787,7 +916,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -808,6 +938,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -832,7 +966,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -844,6 +978,7 @@ class ModelVersion(lineage_node.LineageNode):
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -854,7 +989,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -877,6 +1013,11 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+
 
         Raises:
             ValueError: Illegal external access integration arguments.
@@ -885,6 +1026,9 @@ class ModelVersion(lineage_node.LineageNode):
         Returns:
             If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
+
+        Raises:
+            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -906,7 +1050,18 @@ class ModelVersion(lineage_node.LineageNode):
             build_external_access_integrations = [build_external_access_integration]
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
-
+
+        # Check if model is HuggingFace text-generation before doing inference engine checks
+        if experimental_options:
+            self._check_huggingface_text_generation_model(statement_params)
+
+        inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
+            experimental_options
+        )
+
+        # Enrich inference engine args if inference engine is specified
+        if inference_engine_args is not None:
+            inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
 
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
@@ -929,7 +1084,7 @@ class ModelVersion(lineage_node.LineageNode):
                 else sql_identifier.SqlIdentifier(service_compute_pool)
             ),
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-
+            image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
@@ -946,6 +1101,7 @@ class ModelVersion(lineage_node.LineageNode):
             block=block,
             statement_params=statement_params,
             progress_status=status,
+            inference_engine_args=inference_engine_args,
         )
         status.update(label="Model service created successfully", state="complete", expanded=False)
         return result
@@ -1039,7 +1195,7 @@ class ModelVersion(lineage_node.LineageNode):
         *,
         job_name: str,
         compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         output_table_name: str,
         function_name: Optional[str] = None,
         cpu_requests: Optional[str] = None,
@@ -1074,7 +1230,7 @@ class ModelVersion(lineage_node.LineageNode):
             job_name=job_id,
             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
             warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-
+            image_repo_name=image_repo,
             output_table_database_name=output_table_db_id,
             output_table_schema_name=output_table_schema_id,
             output_table_name=output_table_id,
```
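The other change running through the `create_service` and batch-job hunks above is that `image_repo` is now `Optional[str] = None`. A short sketch, reusing the assumed `mv` handle from the first example:

```python
# As of 1.11.0 the image repository may be omitted; per the updated docstring,
# a default hidden image repository is then used.
mv.create_service(
    service_name="TEXT_GEN_SERVICE",          # illustrative name
    service_compute_pool="MY_GPU_POOL",       # illustrative name
    # image_repo="MY_DB.MY_SCHEMA.MY_REPO",   # now optional; fully qualified or not
)
```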
snowflake/ml/model/_client/ops/service_ops.py

```diff
@@ -12,7 +12,11 @@ from typing import Any, Optional, Union, cast
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
 from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
-from snowflake.ml.model import
+from snowflake.ml.model import (
+    inference_engine as inference_engine_module,
+    model_signature,
+    type_hints,
+)
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.ml.model._signatures import snowpark_handler
@@ -131,6 +135,12 @@ class HFModelArgs:
     warehouse: Optional[str] = None
 
 
+@dataclasses.dataclass
+class InferenceEngineArgs:
+    inference_engine: inference_engine_module.InferenceEngine
+    inference_engine_args_override: Optional[list[str]] = None
+
+
 class ServiceOperator:
     """Service operator for container services logic."""
 
@@ -180,7 +190,7 @@ class ServiceOperator:
         service_name: sql_identifier.SqlIdentifier,
         image_build_compute_pool_name: sql_identifier.SqlIdentifier,
         service_compute_pool_name: sql_identifier.SqlIdentifier,
-
+        image_repo_name: Optional[str],
         ingress_enabled: bool,
         max_instances: int,
         cpu_requests: Optional[str],
@@ -195,6 +205,8 @@ class ServiceOperator:
         statement_params: Optional[dict[str, Any]] = None,
         # hf model
         hf_model_args: Optional[HFModelArgs] = None,
+        # inference engine model
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> Union[str, async_job.AsyncJob]:
 
         # Generate operation ID for this deployment
@@ -205,15 +217,14 @@ class ServiceOperator:
         schema_name = schema_name or self._schema_name
 
         # Fall back to the model's database and schema if not provided then to the registry's database and schema
-        service_database_name = service_database_name or database_name
-        service_schema_name = service_schema_name or schema_name
+        service_database_name = service_database_name or database_name
+        service_schema_name = service_schema_name or schema_name
 
-
-
-
-
-
-        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+        image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
+
+        # There may be more conditions to enable image build in the future
+        # For now, we only enable image build if inference engine is not specified
+        is_enable_image_build = inference_engine_args is None
 
         # Step 1: Preparing deployment artifacts
         progress_status.update("preparing deployment artifacts...")
@@ -230,14 +241,15 @@ class ServiceOperator:
             model_name=model_name,
             version_name=version_name,
         )
-
-
-
-
-
-
-
-
+
+        if is_enable_image_build:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=image_build_compute_pool_name,
+                fully_qualified_image_repo_name=image_repo_fqn,
+                force_rebuild=force_rebuild,
+                external_access_integrations=build_external_access_integrations,
+            )
+
         self._model_deployment_spec.add_service_spec(
             service_database_name=service_database_name,
             service_schema_name=service_schema_name,
@@ -266,6 +278,13 @@ class ServiceOperator:
                 warehouse=hf_model_args.warehouse,
                 **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
             )
+
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+
         spec_yaml_str_or_path = self._model_deployment_spec.save()
 
         # Step 2: Uploading deployment artifacts
@@ -412,6 +431,29 @@ class ServiceOperator:
 
         return async_job
 
+    @staticmethod
+    def _get_image_repo_fqn(
+        image_repo_name: Optional[str],
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+    ) -> Optional[str]:
+        """Get the fully qualified name of the image repository."""
+        if image_repo_name is None or image_repo_name.strip() == "":
+            return None
+        # Parse image repo
+        (
+            image_repo_database_name,
+            image_repo_schema_name,
+            image_repo_name,
+        ) = sql_identifier.parse_fully_qualified_name(image_repo_name)
+        image_repo_database_name = image_repo_database_name or database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name
+        return identifier.get_schema_level_object_identifier(
+            db=image_repo_database_name.identifier(),
+            schema=image_repo_schema_name.identifier(),
+            object_name=image_repo_name.identifier(),
+        )
+
     def _start_service_log_streaming(
         self,
         async_job: snowpark.AsyncJob,
@@ -838,7 +880,7 @@ class ServiceOperator:
         job_name: sql_identifier.SqlIdentifier,
         compute_pool_name: sql_identifier.SqlIdentifier,
         warehouse_name: sql_identifier.SqlIdentifier,
-
+        image_repo_name: Optional[str],
         output_table_database_name: Optional[sql_identifier.SqlIdentifier],
         output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
         output_table_name: sql_identifier.SqlIdentifier,
@@ -859,12 +901,7 @@ class ServiceOperator:
         job_database_name = job_database_name or database_name or self._database_name
         job_schema_name = job_schema_name or schema_name or self._schema_name
 
-
-        image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
-            image_repo
-        )
-        image_repo_database_name = image_repo_database_name or database_name or self._database_name
-        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+        image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
 
         input_table_database_name = job_database_name
         input_table_schema_name = job_schema_name
@@ -948,9 +985,7 @@ class ServiceOperator:
 
         self._model_deployment_spec.add_image_build_spec(
             image_build_compute_pool_name=compute_pool_name,
-
-            image_repo_schema_name=image_repo_schema_name,
-            image_repo_name=image_repo_name,
+            fully_qualified_image_repo_name=image_repo_fqn,
             force_rebuild=force_rebuild,
             external_access_integrations=build_external_access_integrations,
         )
```
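The name resolution performed by the new `_get_image_repo_fqn` helper can be summarized with a simplified analogue (illustrative only; the real helper goes through `sql_identifier.parse_fully_qualified_name` and handles quoted identifiers):

```python
def image_repo_fqn(repo, db: str, schema: str):
    """Simplified analogue of ServiceOperator._get_image_repo_fqn."""
    if repo is None or repo.strip() == "":
        return None  # caller then falls back to the default hidden image repository
    parts = repo.split(".")
    if len(parts) == 1:
        parts = [db, schema, parts[0]]    # unqualified: inherit the model's database and schema
    elif len(parts) == 2:
        parts = [db, parts[0], parts[1]]  # schema-qualified: inherit the database
    return ".".join(parts)

assert image_repo_fqn(None, "DB", "SCH") is None
assert image_repo_fqn("REPO", "DB", "SCH") == "DB.SCH.REPO"
assert image_repo_fqn("OTHER_DB.S2.REPO", "DB", "SCH") == "OTHER_DB.S2.REPO"
```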
snowflake/ml/model/_client/service/model_deployment_spec.py

```diff
@@ -1,10 +1,12 @@
 import json
 import pathlib
+import warnings
 from typing import Any, Optional, Union
 
 import yaml
 
 from snowflake.ml._internal.utils import identifier, sql_identifier
+from snowflake.ml.model import inference_engine as inference_engine_module
 from snowflake.ml.model._client.service import model_deployment_spec_schema
 
 
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
         self._service: Optional[model_deployment_spec_schema.Service] = None
         self._job: Optional[model_deployment_spec_schema.Job] = None
         self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
+        # this is referring to custom inference engine spec (vllm, sglang, etc)
+        self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
         self._inference_spec: dict[str, Any] = {}  # Common inference spec for service/job
 
         self.database: Optional[sql_identifier.SqlIdentifier] = None
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
 
     def add_image_build_spec(
         self,
-        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
-
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
+        fully_qualified_image_repo_name: Optional[str] = None,
         force_rebuild: bool = False,
         external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
     ) -> "ModelDeploymentSpec":
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
 
         Args:
             image_build_compute_pool_name: Compute pool for image building.
-
-            image_repo_database_name: Database name for the image repository.
-            image_repo_schema_name: Schema name for the image repository.
+            fully_qualified_image_repo_name: Fully qualified name of the image repository.
             force_rebuild: Whether to force rebuilding the image.
             external_access_integrations: List of external access integrations.
 
         Returns:
             Self for chaining.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            ),
-        )
+        if (
+            image_build_compute_pool_name is not None
+            or fully_qualified_image_repo_name is not None
+            or force_rebuild is True
+            or external_access_integrations is not None
+        ):
+            self._image_build = model_deployment_spec_schema.ImageBuild(
+                compute_pool=(
+                    None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
+                ),
+                image_repo=fully_qualified_image_repo_name,
+                force_rebuild=force_rebuild,
+                external_access_integrations=(
+                    [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
+                ),
+            )
         return self
 
     def _add_inference_spec(
@@ -363,6 +361,86 @@ class ModelDeploymentSpec:
         self._model_loggings.append(model_logging)
         return self
 
+    def add_inference_engine_spec(
+        self,
+        inference_engine: inference_engine_module.InferenceEngine,
+        inference_engine_args: Optional[list[str]] = None,
+    ) -> "ModelDeploymentSpec":
+        """Add inference engine specification. This must be called after self.add_service_spec().
+
+        Args:
+            inference_engine: Inference engine.
+            inference_engine_args: Inference engine arguments.
+
+        Returns:
+            Self for chaining.
+
+        Raises:
+            ValueError: If inference engine specification is called before add_service_spec().
+            ValueError: If the argument does not have a '--' prefix.
+        """
+        # TODO: needs to eventually support job deployment spec
+        if self._service is None:
+            raise ValueError("Inference engine specification must be called after add_service_spec().")
+
+        if inference_engine_args is None:
+            inference_engine_args = []
+
+        # Validate inference engine
+        if inference_engine == inference_engine_module.InferenceEngine.VLLM:
+            # Block list for VLLM args that should not be user-configurable
+            # make this a set for faster lookup
+            block_list = {
+                "--host",
+                "--port",
+                "--allowed-headers",
+                "--api-key",
+                "--lora-modules",
+                "--prompt-adapter",
+                "--ssl-keyfile",
+                "--ssl-certfile",
+                "--ssl-ca-certs",
+                "--enable-ssl-refresh",
+                "--ssl-cert-reqs",
+                "--root-path",
+                "--middleware",
+                "--disable-frontend-multiprocessing",
+                "--enable-request-id-headers",
+                "--enable-auto-tool-choice",
+                "--tool-call-parser",
+                "--tool-parser-plugin",
+                "--log-config-file",
+            }
+
+            filtered_args = []
+            for arg in inference_engine_args:
+                # Check if the argument has a '--' prefix
+                if not arg.startswith("--"):
+                    raise ValueError(
+                        f"""The argument {arg} is not allowed for configuration in Snowflake ML's
+                        {inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
+                    )
+
+                # Filter out blocked args and warn user
+                if arg.split("=")[0] in block_list:
+                    warnings.warn(
+                        f"""The argument {arg} is not allowed for configuration in Snowflake ML's
+                        {inference_engine.value} inference engine. It will be ignored.""",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    filtered_args.append(arg)
+
+            inference_engine_args = filtered_args
+
+        self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+            # convert to string to be saved in the deployment spec
+            inference_engine_name=inference_engine.value,
+            inference_engine_args=inference_engine_args,
+        )
+        return self
+
     def save(self) -> str:
         """Constructs the final deployment spec from added components and saves it.
 
```
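The vLLM argument validation in `add_inference_engine_spec` above reduces to a filter that rejects un-prefixed arguments and drops blocked ones with a warning. A runnable sketch (the block list here is a small subset of the one in the diff):

```python
import warnings

BLOCKED = {"--host", "--port", "--api-key"}  # subset of the vLLM block list above

def filter_engine_args(args):
    """Illustrative analogue of the validation in add_inference_engine_spec."""
    kept = []
    for arg in args:
        if not arg.startswith("--"):
            raise ValueError(f"{arg!r} must carry a '--' prefix")
        if arg.split("=")[0] in BLOCKED:
            warnings.warn(f"{arg!r} is not user-configurable; it will be ignored", UserWarning)
        else:
            kept.append(arg)
    return kept

print(filter_engine_args(["--max-model-len=8192", "--port=9000"]))
# ['--max-model-len=8192'], plus a UserWarning for --port=9000
```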
snowflake/ml/model/_client/service/model_deployment_spec.py (continued)

```diff
@@ -377,8 +455,6 @@ class ModelDeploymentSpec:
         # Validations
         if not self._models:
             raise ValueError("Model specification is required. Call add_model_spec().")
-        if not self._image_build:
-            raise ValueError("Image build specification is required. Call add_image_build_spec().")
         if not self._service and not self._job:
             raise ValueError(
                 "Either service or job specification is required. Call add_service_spec() or add_job_spec()."
```