snowflake-ml-python 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. snowflake/ml/experiment/callback/keras.py +63 -0
  2. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  3. snowflake/ml/experiment/callback/xgboost.py +5 -1
  4. snowflake/ml/jobs/_utils/__init__.py +0 -0
  5. snowflake/ml/jobs/_utils/constants.py +4 -1
  6. snowflake/ml/jobs/_utils/payload_utils.py +42 -14
  7. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  8. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +3 -3
  11. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  12. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  13. snowflake/ml/jobs/_utils/types.py +5 -7
  14. snowflake/ml/jobs/job.py +1 -1
  15. snowflake/ml/jobs/manager.py +1 -13
  16. snowflake/ml/model/_client/model/model_version_impl.py +166 -10
  17. snowflake/ml/model/_client/ops/service_ops.py +63 -28
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  20. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  22. snowflake/ml/model/inference_engine.py +5 -0
  23. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  24. snowflake/ml/registry/_manager/model_manager.py +7 -35
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +23 -4
  28. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +31 -27
  29. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  30. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  31. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/model/model_version_impl.py

@@ -707,6 +707,128 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )
 
+    def _get_inference_engine_args(
+        self, experimental_options: Optional[dict[str, Any]]
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+
+        if not experimental_options:
+            return None
+
+        if "inference_engine" not in experimental_options:
+            raise ValueError("inference_engine is required in experimental_options")
+
+        return service_ops.InferenceEngineArgs(
+            inference_engine=experimental_options["inference_engine"],
+            inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+        )
+
+    def _enrich_inference_engine_args(
+        self,
+        inference_engine_args: service_ops.InferenceEngineArgs,
+        gpu_requests: Optional[Union[str, int]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Enrich inference engine args with model path and tensor parallelism settings.
+
+        Args:
+            inference_engine_args: The original inference engine args
+            gpu_requests: The number of GPUs requested
+
+        Returns:
+            Enriched inference engine args
+
+        Raises:
+            ValueError: Invalid gpu_requests
+        """
+        if inference_engine_args.inference_engine_args_override is None:
+            inference_engine_args.inference_engine_args_override = []
+
+        # Get model stage path and strip off "snow://" prefix
+        model_stage_path = self._model_ops.get_model_version_stage_path(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+        )
+
+        # Strip "snow://" prefix
+        if model_stage_path.startswith("snow://"):
+            model_stage_path = model_stage_path.replace("snow://", "", 1)
+
+        # Always overwrite the model key by appending
+        inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}")
+
+        gpu_count = None
+
+        # Set tensor-parallelism if gpu_requests is specified
+        if gpu_requests is not None:
+            # assert gpu_requests is a string or an integer before casting to int
+            if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
+                try:
+                    gpu_count = int(gpu_requests)
+                except ValueError:
+                    raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        if gpu_count is not None:
+            if gpu_count > 0:
+                inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+            else:
+                raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
+
+        return inference_engine_args
+
+    def _check_huggingface_text_generation_model(
+        self,
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """Check if the model is a HuggingFace pipeline with text-generation task.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch model spec.
+
+        Raises:
+            ValueError: If the model is not a HuggingFace text-generation model.
+        """
+        # Fetch model spec
+        model_spec = self._model_ops._fetch_model_spec(
+            database_name=None,
+            schema_name=None,
+            model_name=self._model_name,
+            version_name=self._version_name,
+            statement_params=statement_params,
+        )
+
+        # Check if model_type is huggingface_pipeline
+        model_type = model_spec.get("model_type")
+        if model_type != "huggingface_pipeline":
+            raise ValueError(
+                f"Inference engine is only supported for HuggingFace text-generation models. "
+                f"Found model_type: {model_type}"
+            )
+
+        # Check if model supports text-generation task
+        # There should only be one model in the list because we don't support multiple models in a single model spec
+        models = model_spec.get("models", {})
+        is_text_generation = False
+        found_tasks: list[str] = []
+
+        # As long as the model supports text-generation task, we can use it
+        for _, model_info in models.items():
+            options = model_info.get("options", {})
+            task = options.get("task")
+            if task:
+                found_tasks.append(str(task))
+            if task == "text-generation":
+                is_text_generation = True
+                break
+
+        if not is_text_generation:
+            tasks_str = ", ".join(found_tasks)
+            found_tasks_str = (
+                f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
+            )
+            raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
+
     @overload
     def create_service(
         self,
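
The new private helpers above translate `experimental_options` into `service_ops.InferenceEngineArgs` and enrich them with the model's stage path and a tensor-parallelism flag. A worked illustration of `_enrich_inference_engine_args` (the stage path and identifiers are hypothetical; only the "snow://"-stripping and flag-appending behavior is taken from the code above):

    from snowflake.ml.model import inference_engine as ie
    from snowflake.ml.model._client.ops import service_ops

    # Hypothetical starting point: the user asked for vLLM with one override flag.
    args = service_ops.InferenceEngineArgs(
        inference_engine=ie.InferenceEngine.VLLM,
        inference_engine_args_override=["--max-model-len=8192"],
    )

    # With gpu_requests="4" and a hypothetical stage path of
    # "snow://model/DB.SCH.LLAMA/versions/V1", enrichment appends two flags:
    #   ["--max-model-len=8192",
    #    "--model=model/DB.SCH.LLAMA/versions/V1",    # "snow://" stripped
    #    "--tensor-parallel-size=4"]                  # int("4") GPUs
    # gpu_requests=0 or a non-numeric string raises ValueError instead.
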
@@ -714,7 +836,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -725,6 +847,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -735,7 +858,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -756,6 +880,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -766,7 +894,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -777,6 +905,7 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
    ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -787,7 +916,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -808,6 +938,10 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
         """
         ...
 
@@ -832,7 +966,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         image_build_compute_pool: Optional[str] = None,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
@@ -844,6 +978,7 @@ class ModelVersion(lineage_node.LineageNode):
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
 
@@ -854,7 +989,8 @@ class ModelVersion(lineage_node.LineageNode):
                 the service compute pool if None.
             service_compute_pool: The name of the compute pool used to run the inference service.
             image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
-                or schema of the model will be used.
+                or schema of the model will be used. This can be None, in that case a default hidden image repository
+                will be used.
             ingress_enabled: If true, creates an service endpoint associated with the service. User must have
                 BIND SERVICE ENDPOINT privilege on the account.
             max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -877,6 +1013,11 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
+            experimental_options: Experimental options for the service creation with custom inference engine.
+                Currently, only `inference_engine` and `inference_engine_args_override` are supported.
+                `inference_engine` is the name of the inference engine to use.
+                `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
+
 
         Raises:
             ValueError: Illegal external access integration arguments.
@@ -885,6 +1026,9 @@ class ModelVersion(lineage_node.LineageNode):
         Returns:
             If `block=True`, return result information about service creation from server.
             Otherwise, return the service creation AsyncJob.
+
+        Raises:
+            ValueError: Illegal external access integration arguments.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -906,7 +1050,18 @@ class ModelVersion(lineage_node.LineageNode):
             build_external_access_integrations = [build_external_access_integration]
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
-        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
+
+        # Check if model is HuggingFace text-generation before doing inference engine checks
+        if experimental_options:
+            self._check_huggingface_text_generation_model(statement_params)
+
+        inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
+            experimental_options
+        )
+
+        # Enrich inference engine args if inference engine is specified
+        if inference_engine_args is not None:
+            inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
 
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
@@ -929,7 +1084,7 @@ class ModelVersion(lineage_node.LineageNode):
                     else sql_identifier.SqlIdentifier(service_compute_pool)
                 ),
                 service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-                image_repo=image_repo,
+                image_repo_name=image_repo,
                 ingress_enabled=ingress_enabled,
                 max_instances=max_instances,
                 cpu_requests=cpu_requests,
@@ -946,6 +1101,7 @@ class ModelVersion(lineage_node.LineageNode):
                 block=block,
                 statement_params=statement_params,
                 progress_status=status,
+                inference_engine_args=inference_engine_args,
             )
             status.update(label="Model service created successfully", state="complete", expanded=False)
             return result
@@ -1039,7 +1195,7 @@ class ModelVersion(lineage_node.LineageNode):
         *,
         job_name: str,
         compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         output_table_name: str,
         function_name: Optional[str] = None,
         cpu_requests: Optional[str] = None,
@@ -1074,7 +1230,7 @@ class ModelVersion(lineage_node.LineageNode):
            job_name=job_id,
            compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
            warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-           image_repo=image_repo,
+           image_repo_name=image_repo,
            output_table_database_name=output_table_db_id,
            output_table_schema_name=output_table_schema_id,
            output_table_name=output_table_id,
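
Net effect of the model_version_impl.py hunks: `image_repo` becomes optional everywhere, and `create_service` gains an `experimental_options` escape hatch for custom inference engines. A hedged usage sketch (session, model, service, and compute-pool names are hypothetical; the `experimental_options` keys and the vLLM engine value come from the hunks above and the new `snowflake/ml/model/inference_engine.py` module listed in this release):

    from snowflake.ml.model import inference_engine as ie
    from snowflake.ml.registry import Registry

    reg = Registry(session)  # assumes an active Snowpark session
    mv = reg.get_model("LLAMA_MODEL").version("V1")  # hypothetical identifiers

    # image_repo may now be omitted; a default hidden image repository is used.
    # The model must be a HuggingFace "text-generation" pipeline, otherwise
    # _check_huggingface_text_generation_model() raises ValueError.
    mv.create_service(
        service_name="LLAMA_SERVICE",
        service_compute_pool="MY_GPU_POOL",
        gpu_requests=4,  # enriched into --tensor-parallel-size=4
        ingress_enabled=True,
        experimental_options={
            "inference_engine": ie.InferenceEngine.VLLM,
            "inference_engine_args_override": ["--max-model-len=8192"],
        },
    )

Per the service_ops.py hunks below, supplying an inference engine also skips the image-build step for the service.
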
snowflake/ml/model/_client/ops/service_ops.py

@@ -12,7 +12,11 @@ from typing import Any, Optional, Union, cast
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
 from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
-from snowflake.ml.model import model_signature, type_hints
+from snowflake.ml.model import (
+    inference_engine as inference_engine_module,
+    model_signature,
+    type_hints,
+)
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.ml.model._signatures import snowpark_handler
@@ -131,6 +135,12 @@ class HFModelArgs:
     warehouse: Optional[str] = None
 
 
+@dataclasses.dataclass
+class InferenceEngineArgs:
+    inference_engine: inference_engine_module.InferenceEngine
+    inference_engine_args_override: Optional[list[str]] = None
+
+
 class ServiceOperator:
     """Service operator for container services logic."""
 
@@ -180,7 +190,7 @@ class ServiceOperator:
         service_name: sql_identifier.SqlIdentifier,
         image_build_compute_pool_name: sql_identifier.SqlIdentifier,
         service_compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo: str,
+        image_repo_name: Optional[str],
         ingress_enabled: bool,
         max_instances: int,
         cpu_requests: Optional[str],
@@ -195,6 +205,8 @@ class ServiceOperator:
         statement_params: Optional[dict[str, Any]] = None,
         # hf model
         hf_model_args: Optional[HFModelArgs] = None,
+        # inference engine model
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> Union[str, async_job.AsyncJob]:
 
         # Generate operation ID for this deployment
@@ -205,15 +217,14 @@ class ServiceOperator:
         schema_name = schema_name or self._schema_name
 
         # Fall back to the model's database and schema if not provided then to the registry's database and schema
-        service_database_name = service_database_name or database_name or self._database_name
-        service_schema_name = service_schema_name or schema_name or self._schema_name
+        service_database_name = service_database_name or database_name
+        service_schema_name = service_schema_name or schema_name
 
-        # Parse image repo
-        image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
-            image_repo
-        )
-        image_repo_database_name = image_repo_database_name or database_name or self._database_name
-        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+        image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
+
+        # There may be more conditions to enable image build in the future
+        # For now, we only enable image build if inference engine is not specified
+        is_enable_image_build = inference_engine_args is None
 
         # Step 1: Preparing deployment artifacts
         progress_status.update("preparing deployment artifacts...")
@@ -230,14 +241,15 @@ class ServiceOperator:
             model_name=model_name,
             version_name=version_name,
         )
-        self._model_deployment_spec.add_image_build_spec(
-            image_build_compute_pool_name=image_build_compute_pool_name,
-            image_repo_database_name=image_repo_database_name,
-            image_repo_schema_name=image_repo_schema_name,
-            image_repo_name=image_repo_name,
-            force_rebuild=force_rebuild,
-            external_access_integrations=build_external_access_integrations,
-        )
+
+        if is_enable_image_build:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=image_build_compute_pool_name,
+                fully_qualified_image_repo_name=image_repo_fqn,
+                force_rebuild=force_rebuild,
+                external_access_integrations=build_external_access_integrations,
+            )
+
         self._model_deployment_spec.add_service_spec(
             service_database_name=service_database_name,
             service_schema_name=service_schema_name,
@@ -266,6 +278,13 @@ class ServiceOperator:
                 warehouse=hf_model_args.warehouse,
                 **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
             )
+
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+
         spec_yaml_str_or_path = self._model_deployment_spec.save()
 
         # Step 2: Uploading deployment artifacts
@@ -412,6 +431,29 @@ class ServiceOperator:
 
         return async_job
 
+    @staticmethod
+    def _get_image_repo_fqn(
+        image_repo_name: Optional[str],
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+    ) -> Optional[str]:
+        """Get the fully qualified name of the image repository."""
+        if image_repo_name is None or image_repo_name.strip() == "":
+            return None
+        # Parse image repo
+        (
+            image_repo_database_name,
+            image_repo_schema_name,
+            image_repo_name,
+        ) = sql_identifier.parse_fully_qualified_name(image_repo_name)
+        image_repo_database_name = image_repo_database_name or database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name
+        return identifier.get_schema_level_object_identifier(
+            db=image_repo_database_name.identifier(),
+            schema=image_repo_schema_name.identifier(),
+            object_name=image_repo_name.identifier(),
+        )
+
     def _start_service_log_streaming(
         self,
         async_job: snowpark.AsyncJob,
@@ -838,7 +880,7 @@ class ServiceOperator:
         job_name: sql_identifier.SqlIdentifier,
         compute_pool_name: sql_identifier.SqlIdentifier,
         warehouse_name: sql_identifier.SqlIdentifier,
-        image_repo: str,
+        image_repo_name: Optional[str],
         output_table_database_name: Optional[sql_identifier.SqlIdentifier],
         output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
         output_table_name: sql_identifier.SqlIdentifier,
@@ -859,12 +901,7 @@ class ServiceOperator:
         job_database_name = job_database_name or database_name or self._database_name
         job_schema_name = job_schema_name or schema_name or self._schema_name
 
-        # Parse image repo
-        image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
-            image_repo
-        )
-        image_repo_database_name = image_repo_database_name or database_name or self._database_name
-        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+        image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
 
         input_table_database_name = job_database_name
         input_table_schema_name = job_schema_name
@@ -948,9 +985,7 @@ class ServiceOperator:
 
         self._model_deployment_spec.add_image_build_spec(
             image_build_compute_pool_name=compute_pool_name,
-            image_repo_database_name=image_repo_database_name,
-            image_repo_schema_name=image_repo_schema_name,
-            image_repo_name=image_repo_name,
+            fully_qualified_image_repo_name=image_repo_fqn,
            force_rebuild=force_rebuild,
            external_access_integrations=build_external_access_integrations,
        )
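
The image-repo parsing that 1.10.0 duplicated in the service and batch-job paths is now centralized in `_get_image_repo_fqn`, which returns None for a missing repo (signalling the default hidden repository). A simplified sketch of the fallback it implements, using plain strings instead of `sql_identifier` parsing (so quoted identifiers are not handled here):

    from typing import Optional

    def image_repo_fqn_sketch(image_repo_name: Optional[str], database: str, schema: str) -> Optional[str]:
        # None or blank -> no image repo FQN; the default hidden repo is used.
        if image_repo_name is None or image_repo_name.strip() == "":
            return None
        parts = image_repo_name.split(".")
        # Partially qualified names inherit the model's database/schema.
        if len(parts) == 1:
            db, sch, repo = database, schema, parts[0]
        elif len(parts) == 2:
            db, (sch, repo) = database, parts
        else:
            db, sch, repo = parts
        return f"{db}.{sch}.{repo}"

    assert image_repo_fqn_sketch(None, "DB", "SCH") is None
    assert image_repo_fqn_sketch("REPO", "DB", "SCH") == "DB.SCH.REPO"
    assert image_repo_fqn_sketch("SCH2.REPO", "DB", "SCH") == "DB.SCH2.REPO"
    assert image_repo_fqn_sketch("DB2.SCH2.REPO", "DB", "SCH") == "DB2.SCH2.REPO"
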
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -1,10 +1,12 @@
 import json
 import pathlib
+import warnings
 from typing import Any, Optional, Union
 
 import yaml
 
 from snowflake.ml._internal.utils import identifier, sql_identifier
+from snowflake.ml.model import inference_engine as inference_engine_module
 from snowflake.ml.model._client.service import model_deployment_spec_schema
 
 
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
         self._service: Optional[model_deployment_spec_schema.Service] = None
         self._job: Optional[model_deployment_spec_schema.Job] = None
         self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
+        # this is referring to custom inference engine spec (vllm, sglang, etc)
+        self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
         self._inference_spec: dict[str, Any] = {}  # Common inference spec for service/job
 
         self.database: Optional[sql_identifier.SqlIdentifier] = None
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
 
     def add_image_build_spec(
         self,
-        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
+        fully_qualified_image_repo_name: Optional[str] = None,
         force_rebuild: bool = False,
         external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
     ) -> "ModelDeploymentSpec":
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
 
         Args:
             image_build_compute_pool_name: Compute pool for image building.
-            image_repo_name: Name of the image repository.
-            image_repo_database_name: Database name for the image repository.
-            image_repo_schema_name: Schema name for the image repository.
+            fully_qualified_image_repo_name: Fully qualified name of the image repository.
             force_rebuild: Whether to force rebuilding the image.
             external_access_integrations: List of external access integrations.
 
         Returns:
             Self for chaining.
         """
-        saved_image_repo_database = image_repo_database_name or self.database
-        saved_image_repo_schema = image_repo_schema_name or self.schema
-        assert saved_image_repo_database is not None
-        assert saved_image_repo_schema is not None
-        fq_image_repo_name = identifier.get_schema_level_object_identifier(
-            db=saved_image_repo_database.identifier(),
-            schema=saved_image_repo_schema.identifier(),
-            object_name=image_repo_name.identifier(),
-        )
-
-        self._image_build = model_deployment_spec_schema.ImageBuild(
-            compute_pool=image_build_compute_pool_name.identifier(),
-            image_repo=fq_image_repo_name,
-            force_rebuild=force_rebuild,
-            external_access_integrations=(
-                [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
-            ),
-        )
+        if (
+            image_build_compute_pool_name is not None
+            or fully_qualified_image_repo_name is not None
+            or force_rebuild is True
+            or external_access_integrations is not None
+        ):
+            self._image_build = model_deployment_spec_schema.ImageBuild(
+                compute_pool=(
+                    None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
+                ),
+                image_repo=fully_qualified_image_repo_name,
+                force_rebuild=force_rebuild,
+                external_access_integrations=(
+                    [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
+                ),
+            )
         return self
 
     def _add_inference_spec(
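
With every parameter of `add_image_build_spec` now optional, a bare call is a no-op and `self._image_build` stays unset, which pairs with the removal of the image-build validation from `save()` in the last hunk below. A sketch of the resulting contract (assuming `ModelDeploymentSpec` can be constructed standalone and initializes `_image_build` to None, which this diff does not show):

    from snowflake.ml.model._client.service import model_deployment_spec

    spec = model_deployment_spec.ModelDeploymentSpec()

    spec.add_image_build_spec()  # all arguments falsy: no ImageBuild is created
    assert spec._image_build is None

    spec.add_image_build_spec(force_rebuild=True)  # any non-default argument creates one
    assert spec._image_build is not None
    assert spec._image_build.image_repo is None  # repo left to the server default
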
@@ -363,6 +361,86 @@ class ModelDeploymentSpec:
         self._model_loggings.append(model_logging)
         return self
 
+    def add_inference_engine_spec(
+        self,
+        inference_engine: inference_engine_module.InferenceEngine,
+        inference_engine_args: Optional[list[str]] = None,
+    ) -> "ModelDeploymentSpec":
+        """Add inference engine specification. This must be called after self.add_service_spec().
+
+        Args:
+            inference_engine: Inference engine.
+            inference_engine_args: Inference engine arguments.
+
+        Returns:
+            Self for chaining.
+
+        Raises:
+            ValueError: If inference engine specification is called before add_service_spec().
+            ValueError: If the argument does not have a '--' prefix.
+        """
+        # TODO: needs to eventually support job deployment spec
+        if self._service is None:
+            raise ValueError("Inference engine specification must be called after add_service_spec().")
+
+        if inference_engine_args is None:
+            inference_engine_args = []
+
+        # Validate inference engine
+        if inference_engine == inference_engine_module.InferenceEngine.VLLM:
+            # Block list for VLLM args that should not be user-configurable
+            # make this a set for faster lookup
+            block_list = {
+                "--host",
+                "--port",
+                "--allowed-headers",
+                "--api-key",
+                "--lora-modules",
+                "--prompt-adapter",
+                "--ssl-keyfile",
+                "--ssl-certfile",
+                "--ssl-ca-certs",
+                "--enable-ssl-refresh",
+                "--ssl-cert-reqs",
+                "--root-path",
+                "--middleware",
+                "--disable-frontend-multiprocessing",
+                "--enable-request-id-headers",
+                "--enable-auto-tool-choice",
+                "--tool-call-parser",
+                "--tool-parser-plugin",
+                "--log-config-file",
+            }
+
+            filtered_args = []
+            for arg in inference_engine_args:
+                # Check if the argument has a '--' prefix
+                if not arg.startswith("--"):
+                    raise ValueError(
+                        f"""The argument {arg} is not allowed for configuration in Snowflake ML's
+                        {inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
+                    )
+
+                # Filter out blocked args and warn user
+                if arg.split("=")[0] in block_list:
+                    warnings.warn(
+                        f"""The argument {arg} is not allowed for configuration in Snowflake ML's
+                        {inference_engine.value} inference engine. It will be ignored.""",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    filtered_args.append(arg)
+
+            inference_engine_args = filtered_args
+
+        self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+            # convert to string to be saved in the deployment spec
+            inference_engine_name=inference_engine.value,
+            inference_engine_args=inference_engine_args,
+        )
+        return self
+
     def save(self) -> str:
         """Constructs the final deployment spec from added components and saves it.
 
@@ -377,8 +455,6 @@ class ModelDeploymentSpec:
         # Validations
         if not self._models:
             raise ValueError("Model specification is required. Call add_model_spec().")
-        if not self._image_build:
-            raise ValueError("Image build specification is required. Call add_image_build_spec().")
        if not self._service and not self._job:
            raise ValueError(
                "Either service or job specification is required. Call add_service_spec() or add_job_spec()."