snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. snowflake/ml/_internal/utils/service_logger.py +31 -17
  2. snowflake/ml/experiment/callback/keras.py +63 -0
  3. snowflake/ml/experiment/callback/lightgbm.py +59 -0
  4. snowflake/ml/experiment/callback/xgboost.py +67 -0
  5. snowflake/ml/experiment/utils.py +14 -0
  6. snowflake/ml/jobs/_utils/__init__.py +0 -0
  7. snowflake/ml/jobs/_utils/constants.py +4 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +55 -21
  9. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  10. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
  13. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  14. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  15. snowflake/ml/jobs/_utils/types.py +5 -7
  16. snowflake/ml/jobs/job.py +1 -1
  17. snowflake/ml/jobs/manager.py +1 -13
  18. snowflake/ml/model/_client/model/model_version_impl.py +219 -55
  19. snowflake/ml/model/_client/ops/service_ops.py +230 -30
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  21. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  24. snowflake/ml/model/event_handler.py +87 -18
  25. snowflake/ml/model/inference_engine.py +5 -0
  26. snowflake/ml/model/models/huggingface_pipeline.py +74 -51
  27. snowflake/ml/model/type_hints.py +26 -1
  28. snowflake/ml/registry/_manager/model_manager.py +37 -70
  29. snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
  30. snowflake/ml/registry/registry.py +0 -19
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
  33. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
  34. snowflake/ml/experiment/callback.py +0 -121
  35. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,11 @@ from typing import Any, Optional, Union, cast
12
12
  from snowflake import snowpark
13
13
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
14
14
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
15
- from snowflake.ml.model import model_signature, type_hints
15
+ from snowflake.ml.model import (
16
+ inference_engine as inference_engine_module,
17
+ model_signature,
18
+ type_hints,
19
+ )
16
20
  from snowflake.ml.model._client.service import model_deployment_spec
17
21
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
18
22
  from snowflake.ml.model._signatures import snowpark_handler
@@ -131,6 +135,12 @@ class HFModelArgs:
131
135
  warehouse: Optional[str] = None
132
136
 
133
137
 
138
+ @dataclasses.dataclass
139
+ class InferenceEngineArgs:
140
+ inference_engine: inference_engine_module.InferenceEngine
141
+ inference_engine_args_override: Optional[list[str]] = None
142
+
143
+
134
144
  class ServiceOperator:
135
145
  """Service operator for container services logic."""
136
146
 
@@ -180,9 +190,7 @@ class ServiceOperator:
180
190
  service_name: sql_identifier.SqlIdentifier,
181
191
  image_build_compute_pool_name: sql_identifier.SqlIdentifier,
182
192
  service_compute_pool_name: sql_identifier.SqlIdentifier,
183
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
184
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
185
- image_repo_name: sql_identifier.SqlIdentifier,
193
+ image_repo_name: Optional[str],
186
194
  ingress_enabled: bool,
187
195
  max_instances: int,
188
196
  cpu_requests: Optional[str],
@@ -193,9 +201,12 @@ class ServiceOperator:
193
201
  force_rebuild: bool,
194
202
  build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
195
203
  block: bool,
204
+ progress_status: type_hints.ProgressStatus,
196
205
  statement_params: Optional[dict[str, Any]] = None,
197
206
  # hf model
198
207
  hf_model_args: Optional[HFModelArgs] = None,
208
+ # inference engine model
209
+ inference_engine_args: Optional[InferenceEngineArgs] = None,
199
210
  ) -> Union[str, async_job.AsyncJob]:
200
211
 
201
212
  # Generate operation ID for this deployment
@@ -206,11 +217,19 @@ class ServiceOperator:
206
217
  schema_name = schema_name or self._schema_name
207
218
 
208
219
  # Fall back to the model's database and schema if not provided then to the registry's database and schema
209
- service_database_name = service_database_name or database_name or self._database_name
210
- service_schema_name = service_schema_name or schema_name or self._schema_name
220
+ service_database_name = service_database_name or database_name
221
+ service_schema_name = service_schema_name or schema_name
222
+
223
+ image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
224
+
225
+ # There may be more conditions to enable image build in the future
226
+ # For now, we only enable image build if inference engine is not specified
227
+ is_enable_image_build = inference_engine_args is None
228
+
229
+ # Step 1: Preparing deployment artifacts
230
+ progress_status.update("preparing deployment artifacts...")
231
+ progress_status.increment()
211
232
 
212
- image_repo_database_name = image_repo_database_name or database_name or self._database_name
213
- image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
214
233
  if self._workspace:
215
234
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
216
235
  else:
@@ -222,14 +241,15 @@ class ServiceOperator:
222
241
  model_name=model_name,
223
242
  version_name=version_name,
224
243
  )
225
- self._model_deployment_spec.add_image_build_spec(
226
- image_build_compute_pool_name=image_build_compute_pool_name,
227
- image_repo_database_name=image_repo_database_name,
228
- image_repo_schema_name=image_repo_schema_name,
229
- image_repo_name=image_repo_name,
230
- force_rebuild=force_rebuild,
231
- external_access_integrations=build_external_access_integrations,
232
- )
244
+
245
+ if is_enable_image_build:
246
+ self._model_deployment_spec.add_image_build_spec(
247
+ image_build_compute_pool_name=image_build_compute_pool_name,
248
+ fully_qualified_image_repo_name=image_repo_fqn,
249
+ force_rebuild=force_rebuild,
250
+ external_access_integrations=build_external_access_integrations,
251
+ )
252
+
233
253
  self._model_deployment_spec.add_service_spec(
234
254
  service_database_name=service_database_name,
235
255
  service_schema_name=service_schema_name,
@@ -258,7 +278,19 @@ class ServiceOperator:
258
278
  warehouse=hf_model_args.warehouse,
259
279
  **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
260
280
  )
281
+
282
+ if inference_engine_args:
283
+ self._model_deployment_spec.add_inference_engine_spec(
284
+ inference_engine=inference_engine_args.inference_engine,
285
+ inference_engine_args=inference_engine_args.inference_engine_args_override,
286
+ )
287
+
261
288
  spec_yaml_str_or_path = self._model_deployment_spec.save()
289
+
290
+ # Step 2: Uploading deployment artifacts
291
+ progress_status.update("uploading deployment artifacts...")
292
+ progress_status.increment()
293
+
262
294
  if self._workspace:
263
295
  assert stage_path is not None
264
296
  file_utils.upload_directory_to_stage(
@@ -281,6 +313,10 @@ class ServiceOperator:
281
313
  statement_params=statement_params,
282
314
  )
283
315
 
316
+ # Step 3: Initiating model deployment
317
+ progress_status.update("initiating model deployment...")
318
+ progress_status.increment()
319
+
284
320
  # deploy the model service
285
321
  query_id, async_job = self._service_client.deploy_model(
286
322
  stage_path=stage_path if self._workspace else None,
@@ -337,13 +373,86 @@ class ServiceOperator:
337
373
  )
338
374
 
339
375
  if block:
340
- log_thread.join()
376
+ try:
377
+ # Step 4: Starting model build: waits for build to start
378
+ progress_status.update("starting model image build...")
379
+ progress_status.increment()
380
+
381
+ # Poll for model build to start if not using existing service
382
+ if not model_inference_service_exists:
383
+ self._wait_for_service_status(
384
+ model_build_service_name,
385
+ service_sql.ServiceStatus.RUNNING,
386
+ service_database_name,
387
+ service_schema_name,
388
+ async_job,
389
+ statement_params,
390
+ )
341
391
 
342
- res = cast(str, cast(list[row.Row], async_job.result())[0][0])
343
- module_logger.info(f"Inference service {service_name} deployment complete: {res}")
344
- return res
345
- else:
346
- return async_job
392
+ # Step 5: Building model image
393
+ progress_status.update("building model image...")
394
+ progress_status.increment()
395
+
396
+ # Poll for model build completion
397
+ if not model_inference_service_exists:
398
+ self._wait_for_service_status(
399
+ model_build_service_name,
400
+ service_sql.ServiceStatus.DONE,
401
+ service_database_name,
402
+ service_schema_name,
403
+ async_job,
404
+ statement_params,
405
+ )
406
+
407
+ # Step 6: Deploying model service (push complete, starting inference service)
408
+ progress_status.update("deploying model service...")
409
+ progress_status.increment()
410
+
411
+ log_thread.join()
412
+
413
+ res = cast(str, cast(list[row.Row], async_job.result())[0][0])
414
+ module_logger.info(f"Inference service {service_name} deployment complete: {res}")
415
+ return res
416
+
417
+ except RuntimeError as e:
418
+ # Handle service creation/deployment failures
419
+ error_msg = f"Model service deployment failed: {str(e)}"
420
+ module_logger.error(error_msg)
421
+
422
+ # Update progress status to show failure
423
+ progress_status.update(error_msg, state="error")
424
+
425
+ # Stop the log thread if it's running
426
+ if "log_thread" in locals() and log_thread.is_alive():
427
+ log_thread.join(timeout=5) # Give it a few seconds to finish gracefully
428
+
429
+ # Re-raise the exception to propagate the error
430
+ raise RuntimeError(error_msg) from e
431
+
432
+ return async_job
433
+
434
+ @staticmethod
435
+ def _get_image_repo_fqn(
436
+ image_repo_name: Optional[str],
437
+ database_name: sql_identifier.SqlIdentifier,
438
+ schema_name: sql_identifier.SqlIdentifier,
439
+ ) -> Optional[str]:
440
+ """Get the fully qualified name of the image repository."""
441
+ if image_repo_name is None or image_repo_name.strip() == "":
442
+ return None
443
+ # Parse image repo
444
+ (
445
+ image_repo_database_name,
446
+ image_repo_schema_name,
447
+ image_repo_name,
448
+ ) = sql_identifier.parse_fully_qualified_name(image_repo_name)
449
+ image_repo_database_name = image_repo_database_name or database_name
450
+ image_repo_schema_name = image_repo_schema_name or schema_name
451
+ return identifier.get_schema_level_object_identifier(
452
+ db=image_repo_database_name.identifier(),
453
+ schema=image_repo_schema_name.identifier(),
454
+ object_name=image_repo_name.identifier(),
455
+ )
347
456
 
348
457
  def _start_service_log_streaming(
349
458
  self,
@@ -579,6 +688,7 @@ class ServiceOperator:
579
688
  is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
580
689
  contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
581
690
  matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
691
+
582
692
  if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
583
693
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
584
694
  time.sleep(5)
@@ -618,6 +728,101 @@ class ServiceOperator:
618
728
  except Exception as ex:
619
729
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
620
730
 
731
+ def _wait_for_service_status(
732
+ self,
733
+ service_name: sql_identifier.SqlIdentifier,
734
+ target_status: service_sql.ServiceStatus,
735
+ service_database_name: Optional[sql_identifier.SqlIdentifier],
736
+ service_schema_name: Optional[sql_identifier.SqlIdentifier],
737
+ async_job: snowpark.AsyncJob,
738
+ statement_params: Optional[dict[str, Any]] = None,
739
+ timeout_minutes: int = 30,
740
+ ) -> None:
741
+ """Wait for service to reach the specified status while monitoring async job for failures.
742
+
743
+ Args:
744
+ service_name: The service to monitor
745
+ target_status: The target status to wait for
746
+ service_database_name: Database containing the service
747
+ service_schema_name: Schema containing the service
748
+ async_job: The async job to monitor for completion/failure
749
+ statement_params: SQL statement parameters
750
+ timeout_minutes: Maximum time to wait before timing out
751
+
752
+ Raises:
753
+ RuntimeError: If service fails, times out, or enters an error state
754
+ """
755
+ start_time = time.time()
756
+ timeout_seconds = timeout_minutes * 60
757
+ service_seen_before = False
758
+
759
+ while True:
760
+ # Check if async job has failed (but don't return on success - we need specific service status)
761
+ if async_job.is_done():
762
+ try:
763
+ async_job.result()
764
+ # Async job completed successfully, but we're waiting for a specific service status
765
+ # This might mean the service completed and was cleaned up
766
+ module_logger.debug(
767
+ f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
768
+ )
769
+ except Exception as e:
770
+ raise RuntimeError(f"Service deployment failed: {e}")
771
+
772
+ try:
773
+ statuses = self._service_client.get_service_container_statuses(
774
+ database_name=service_database_name,
775
+ schema_name=service_schema_name,
776
+ service_name=service_name,
777
+ include_message=True,
778
+ statement_params=statement_params,
779
+ )
780
+
781
+ if statuses:
782
+ service_seen_before = True
783
+ current_status = statuses[0].service_status
784
+
785
+ # Check if we've reached the target status
786
+ if current_status == target_status:
787
+ return
788
+
789
+ # Check for failure states
790
+ if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
791
+ error_msg = f"Service {service_name} failed with status {current_status.value}"
792
+ if statuses[0].message:
793
+ error_msg += f": {statuses[0].message}"
794
+ raise RuntimeError(error_msg)
795
+
796
+ except exceptions.SnowparkSQLException as e:
797
+ # Service might not exist yet - this is expected during initial deployment
798
+ if "does not exist" in str(e) or "002003" in str(e):
799
+ # If we're waiting for DONE status and we've seen the service before,
800
+ # it likely completed and was cleaned up
801
+ if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
802
+ module_logger.debug(
803
+ f"Service {service_name} disappeared after being seen, "
804
+ f"assuming it reached {target_status.value} and was cleaned up"
805
+ )
806
+ return
807
+
808
+ module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
809
+ else:
810
+ # Re-raise unexpected SQL exceptions
811
+ raise RuntimeError(f"Error checking service status: {e}")
812
+ except Exception as e:
813
+ # Re-raise unexpected exceptions instead of masking them
814
+ raise RuntimeError(f"Unexpected error while waiting for service status: {e}")
815
+
816
+ # Check timeout
817
+ elapsed_time = time.time() - start_time
818
+ if elapsed_time > timeout_seconds:
819
+ raise RuntimeError(
820
+ f"Timeout waiting for service {service_name} to reach status {target_status.value} "
821
+ f"after {timeout_minutes} minutes"
822
+ )
823
+
824
+ time.sleep(2) # Poll every 2 seconds
825
+
621
826
  @staticmethod
622
827
  def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
623
828
  """Get the service ID through the server-side logic."""
@@ -675,9 +880,7 @@ class ServiceOperator:
675
880
  job_name: sql_identifier.SqlIdentifier,
676
881
  compute_pool_name: sql_identifier.SqlIdentifier,
677
882
  warehouse_name: sql_identifier.SqlIdentifier,
678
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
679
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
680
- image_repo_name: sql_identifier.SqlIdentifier,
883
+ image_repo_name: Optional[str],
681
884
  output_table_database_name: Optional[sql_identifier.SqlIdentifier],
682
885
  output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
683
886
  output_table_name: sql_identifier.SqlIdentifier,
@@ -698,8 +901,7 @@ class ServiceOperator:
698
901
  job_database_name = job_database_name or database_name or self._database_name
699
902
  job_schema_name = job_schema_name or schema_name or self._schema_name
700
903
 
701
- image_repo_database_name = image_repo_database_name or database_name or self._database_name
702
- image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
904
+ image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
703
905
 
704
906
  input_table_database_name = job_database_name
705
907
  input_table_schema_name = job_schema_name
@@ -783,9 +985,7 @@ class ServiceOperator:
783
985
 
784
986
  self._model_deployment_spec.add_image_build_spec(
785
987
  image_build_compute_pool_name=compute_pool_name,
786
- image_repo_database_name=image_repo_database_name,
787
- image_repo_schema_name=image_repo_schema_name,
788
- image_repo_name=image_repo_name,
988
+ fully_qualified_image_repo_name=image_repo_fqn,
789
989
  force_rebuild=force_rebuild,
790
990
  external_access_integrations=build_external_access_integrations,
791
991
  )
@@ -1,10 +1,12 @@
1
1
  import json
2
2
  import pathlib
3
+ import warnings
3
4
  from typing import Any, Optional, Union
4
5
 
5
6
  import yaml
6
7
 
7
8
  from snowflake.ml._internal.utils import identifier, sql_identifier
9
+ from snowflake.ml.model import inference_engine as inference_engine_module
8
10
  from snowflake.ml.model._client.service import model_deployment_spec_schema
9
11
 
10
12
 
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
24
26
  self._service: Optional[model_deployment_spec_schema.Service] = None
25
27
  self._job: Optional[model_deployment_spec_schema.Job] = None
26
28
  self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
29
+ # this is referring to custom inference engine spec (vllm, sglang, etc)
30
+ self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
27
31
  self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job
28
32
 
29
33
  self.database: Optional[sql_identifier.SqlIdentifier] = None
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
71
75
 
72
76
  def add_image_build_spec(
73
77
  self,
74
- image_build_compute_pool_name: sql_identifier.SqlIdentifier,
75
- image_repo_name: sql_identifier.SqlIdentifier,
76
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
77
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
78
+ image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
79
+ fully_qualified_image_repo_name: Optional[str] = None,
78
80
  force_rebuild: bool = False,
79
81
  external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
80
82
  ) -> "ModelDeploymentSpec":
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
82
84
 
83
85
  Args:
84
86
  image_build_compute_pool_name: Compute pool for image building.
85
- image_repo_name: Name of the image repository.
86
- image_repo_database_name: Database name for the image repository.
87
- image_repo_schema_name: Schema name for the image repository.
87
+ fully_qualified_image_repo_name: Fully qualified name of the image repository.
88
88
  force_rebuild: Whether to force rebuilding the image.
89
89
  external_access_integrations: List of external access integrations.
90
90
 
91
91
  Returns:
92
92
  Self for chaining.
93
93
  """
94
- saved_image_repo_database = image_repo_database_name or self.database
95
- saved_image_repo_schema = image_repo_schema_name or self.schema
96
- assert saved_image_repo_database is not None
97
- assert saved_image_repo_schema is not None
98
- fq_image_repo_name = identifier.get_schema_level_object_identifier(
99
- db=saved_image_repo_database.identifier(),
100
- schema=saved_image_repo_schema.identifier(),
101
- object_name=image_repo_name.identifier(),
102
- )
103
-
104
- self._image_build = model_deployment_spec_schema.ImageBuild(
105
- compute_pool=image_build_compute_pool_name.identifier(),
106
- image_repo=fq_image_repo_name,
107
- force_rebuild=force_rebuild,
108
- external_access_integrations=(
109
- [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
110
- ),
111
- )
94
+ if (
95
+ image_build_compute_pool_name is not None
96
+ or fully_qualified_image_repo_name is not None
97
+ or force_rebuild is True
98
+ or external_access_integrations is not None
99
+ ):
100
+ self._image_build = model_deployment_spec_schema.ImageBuild(
101
+ compute_pool=(
102
+ None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
103
+ ),
104
+ image_repo=fully_qualified_image_repo_name,
105
+ force_rebuild=force_rebuild,
106
+ external_access_integrations=(
107
+ [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
108
+ ),
109
+ )
112
110
  return self
113
111
 
114
112
  def _add_inference_spec(
@@ -363,6 +361,86 @@ class ModelDeploymentSpec:
363
361
  self._model_loggings.append(model_logging)
364
362
  return self
365
363
 
364
+ def add_inference_engine_spec(
365
+ self,
366
+ inference_engine: inference_engine_module.InferenceEngine,
367
+ inference_engine_args: Optional[list[str]] = None,
368
+ ) -> "ModelDeploymentSpec":
369
+ """Add inference engine specification. This must be called after self.add_service_spec().
370
+
371
+ Args:
372
+ inference_engine: Inference engine.
373
+ inference_engine_args: Inference engine arguments.
374
+
375
+ Returns:
376
+ Self for chaining.
377
+
378
+ Raises:
379
+ ValueError: If inference engine specification is called before add_service_spec().
380
+ ValueError: If the argument does not have a '--' prefix.
381
+ """
382
+ # TODO: needs to eventually support job deployment spec
383
+ if self._service is None:
384
+ raise ValueError("Inference engine specification must be called after add_service_spec().")
385
+
386
+ if inference_engine_args is None:
387
+ inference_engine_args = []
388
+
389
+ # Validate inference engine
390
+ if inference_engine == inference_engine_module.InferenceEngine.VLLM:
391
+ # Block list for VLLM args that should not be user-configurable
392
+ # make this a set for faster lookup
393
+ block_list = {
394
+ "--host",
395
+ "--port",
396
+ "--allowed-headers",
397
+ "--api-key",
398
+ "--lora-modules",
399
+ "--prompt-adapter",
400
+ "--ssl-keyfile",
401
+ "--ssl-certfile",
402
+ "--ssl-ca-certs",
403
+ "--enable-ssl-refresh",
404
+ "--ssl-cert-reqs",
405
+ "--root-path",
406
+ "--middleware",
407
+ "--disable-frontend-multiprocessing",
408
+ "--enable-request-id-headers",
409
+ "--enable-auto-tool-choice",
410
+ "--tool-call-parser",
411
+ "--tool-parser-plugin",
412
+ "--log-config-file",
413
+ }
414
+
415
+ filtered_args = []
416
+ for arg in inference_engine_args:
417
+ # Check if the argument has a '--' prefix
418
+ if not arg.startswith("--"):
419
+ raise ValueError(
420
+ f"""The argument {arg} is not allowed for configuration in Snowflake ML's
421
+ {inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
422
+ )
423
+
424
+ # Filter out blocked args and warn user
425
+ if arg.split("=")[0] in block_list:
426
+ warnings.warn(
427
+ f"""The argument {arg} is not allowed for configuration in Snowflake ML's
428
+ {inference_engine.value} inference engine. It will be ignored.""",
429
+ UserWarning,
430
+ stacklevel=2,
431
+ )
432
+ else:
433
+ filtered_args.append(arg)
434
+
435
+ inference_engine_args = filtered_args
436
+
437
+ self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
438
+ # convert to string to be saved in the deployment spec
439
+ inference_engine_name=inference_engine.value,
440
+ inference_engine_args=inference_engine_args,
441
+ )
442
+ return self
443
+
366
444
  def save(self) -> str:
367
445
  """Constructs the final deployment spec from added components and saves it.
368
446
 
@@ -377,8 +455,6 @@ class ModelDeploymentSpec:
377
455
  # Validations
378
456
  if not self._models:
379
457
  raise ValueError("Model specification is required. Call add_model_spec().")
380
- if not self._image_build:
381
- raise ValueError("Image build specification is required. Call add_image_build_spec().")
382
458
  if not self._service and not self._job:
383
459
  raise ValueError(
384
460
  "Either service or job specification is required. Call add_service_spec() or add_job_spec()."
@@ -10,10 +10,15 @@ class Model(BaseModel):
10
10
  version: str
11
11
 
12
12
 
13
+ class InferenceEngineSpec(BaseModel):
14
+ inference_engine_name: str
15
+ inference_engine_args: Optional[list[str]] = None
16
+
17
+
13
18
  class ImageBuild(BaseModel):
14
- compute_pool: str
15
- image_repo: str
16
- force_rebuild: bool
19
+ compute_pool: Optional[str] = None
20
+ image_repo: Optional[str] = None
21
+ force_rebuild: Optional[bool] = None
17
22
  external_access_integrations: Optional[list[str]] = None
18
23
 
19
24
 
@@ -27,6 +32,7 @@ class Service(BaseModel):
27
32
  gpu: Optional[str] = None
28
33
  num_workers: Optional[int] = None
29
34
  max_batch_rows: Optional[int] = None
35
+ inference_engine_spec: Optional[InferenceEngineSpec] = None
30
36
 
31
37
 
32
38
  class Job(BaseModel):
@@ -68,13 +74,13 @@ class ModelLogging(BaseModel):
68
74
 
69
75
  class ModelServiceDeploymentSpec(BaseModel):
70
76
  models: list[Model]
71
- image_build: ImageBuild
77
+ image_build: Optional[ImageBuild] = None
72
78
  service: Service
73
79
  model_loggings: Optional[list[ModelLogging]] = None
74
80
 
75
81
 
76
82
  class ModelJobDeploymentSpec(BaseModel):
77
83
  models: list[Model]
78
- image_build: ImageBuild
84
+ image_build: Optional[ImageBuild] = None
79
85
  job: Job
80
86
  model_loggings: Optional[list[ModelLogging]] = None