snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +59 -0
- snowflake/ml/experiment/callback/xgboost.py +67 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +55 -21
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +219 -55
- snowflake/ml/model/_client/ops/service_ops.py +230 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +74 -51
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +37 -70
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,11 @@ from typing import Any, Optional, Union, cast
|
|
|
12
12
|
from snowflake import snowpark
|
|
13
13
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
14
14
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
15
|
-
from snowflake.ml.model import
|
|
15
|
+
from snowflake.ml.model import (
|
|
16
|
+
inference_engine as inference_engine_module,
|
|
17
|
+
model_signature,
|
|
18
|
+
type_hints,
|
|
19
|
+
)
|
|
16
20
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
17
21
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
18
22
|
from snowflake.ml.model._signatures import snowpark_handler
|
|
@@ -131,6 +135,12 @@ class HFModelArgs:
|
|
|
131
135
|
warehouse: Optional[str] = None
|
|
132
136
|
|
|
133
137
|
|
|
138
|
+
@dataclasses.dataclass
|
|
139
|
+
class InferenceEngineArgs:
|
|
140
|
+
inference_engine: inference_engine_module.InferenceEngine
|
|
141
|
+
inference_engine_args_override: Optional[list[str]] = None
|
|
142
|
+
|
|
143
|
+
|
|
134
144
|
class ServiceOperator:
|
|
135
145
|
"""Service operator for container services logic."""
|
|
136
146
|
|
|
@@ -180,9 +190,7 @@ class ServiceOperator:
|
|
|
180
190
|
service_name: sql_identifier.SqlIdentifier,
|
|
181
191
|
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
182
192
|
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
183
|
-
|
|
184
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
185
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
193
|
+
image_repo_name: Optional[str],
|
|
186
194
|
ingress_enabled: bool,
|
|
187
195
|
max_instances: int,
|
|
188
196
|
cpu_requests: Optional[str],
|
|
@@ -193,9 +201,12 @@ class ServiceOperator:
|
|
|
193
201
|
force_rebuild: bool,
|
|
194
202
|
build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
|
|
195
203
|
block: bool,
|
|
204
|
+
progress_status: type_hints.ProgressStatus,
|
|
196
205
|
statement_params: Optional[dict[str, Any]] = None,
|
|
197
206
|
# hf model
|
|
198
207
|
hf_model_args: Optional[HFModelArgs] = None,
|
|
208
|
+
# inference engine model
|
|
209
|
+
inference_engine_args: Optional[InferenceEngineArgs] = None,
|
|
199
210
|
) -> Union[str, async_job.AsyncJob]:
|
|
200
211
|
|
|
201
212
|
# Generate operation ID for this deployment
|
|
@@ -206,11 +217,19 @@ class ServiceOperator:
|
|
|
206
217
|
schema_name = schema_name or self._schema_name
|
|
207
218
|
|
|
208
219
|
# Fall back to the model's database and schema if not provided then to the registry's database and schema
|
|
209
|
-
service_database_name = service_database_name or database_name
|
|
210
|
-
service_schema_name = service_schema_name or schema_name
|
|
220
|
+
service_database_name = service_database_name or database_name
|
|
221
|
+
service_schema_name = service_schema_name or schema_name
|
|
222
|
+
|
|
223
|
+
image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name)
|
|
224
|
+
|
|
225
|
+
# There may be more conditions to enable image build in the future
|
|
226
|
+
# For now, we only enable image build if inference engine is not specified
|
|
227
|
+
is_enable_image_build = inference_engine_args is None
|
|
228
|
+
|
|
229
|
+
# Step 1: Preparing deployment artifacts
|
|
230
|
+
progress_status.update("preparing deployment artifacts...")
|
|
231
|
+
progress_status.increment()
|
|
211
232
|
|
|
212
|
-
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
213
|
-
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
214
233
|
if self._workspace:
|
|
215
234
|
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
216
235
|
else:
|
|
@@ -222,14 +241,15 @@ class ServiceOperator:
|
|
|
222
241
|
model_name=model_name,
|
|
223
242
|
version_name=version_name,
|
|
224
243
|
)
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
244
|
+
|
|
245
|
+
if is_enable_image_build:
|
|
246
|
+
self._model_deployment_spec.add_image_build_spec(
|
|
247
|
+
image_build_compute_pool_name=image_build_compute_pool_name,
|
|
248
|
+
fully_qualified_image_repo_name=image_repo_fqn,
|
|
249
|
+
force_rebuild=force_rebuild,
|
|
250
|
+
external_access_integrations=build_external_access_integrations,
|
|
251
|
+
)
|
|
252
|
+
|
|
233
253
|
self._model_deployment_spec.add_service_spec(
|
|
234
254
|
service_database_name=service_database_name,
|
|
235
255
|
service_schema_name=service_schema_name,
|
|
@@ -258,7 +278,19 @@ class ServiceOperator:
|
|
|
258
278
|
warehouse=hf_model_args.warehouse,
|
|
259
279
|
**(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
|
|
260
280
|
)
|
|
281
|
+
|
|
282
|
+
if inference_engine_args:
|
|
283
|
+
self._model_deployment_spec.add_inference_engine_spec(
|
|
284
|
+
inference_engine=inference_engine_args.inference_engine,
|
|
285
|
+
inference_engine_args=inference_engine_args.inference_engine_args_override,
|
|
286
|
+
)
|
|
287
|
+
|
|
261
288
|
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
289
|
+
|
|
290
|
+
# Step 2: Uploading deployment artifacts
|
|
291
|
+
progress_status.update("uploading deployment artifacts...")
|
|
292
|
+
progress_status.increment()
|
|
293
|
+
|
|
262
294
|
if self._workspace:
|
|
263
295
|
assert stage_path is not None
|
|
264
296
|
file_utils.upload_directory_to_stage(
|
|
@@ -281,6 +313,10 @@ class ServiceOperator:
|
|
|
281
313
|
statement_params=statement_params,
|
|
282
314
|
)
|
|
283
315
|
|
|
316
|
+
# Step 3: Initiating model deployment
|
|
317
|
+
progress_status.update("initiating model deployment...")
|
|
318
|
+
progress_status.increment()
|
|
319
|
+
|
|
284
320
|
# deploy the model service
|
|
285
321
|
query_id, async_job = self._service_client.deploy_model(
|
|
286
322
|
stage_path=stage_path if self._workspace else None,
|
|
@@ -337,13 +373,86 @@ class ServiceOperator:
|
|
|
337
373
|
)
|
|
338
374
|
|
|
339
375
|
if block:
|
|
340
|
-
|
|
376
|
+
try:
|
|
377
|
+
# Step 4: Starting model build: waits for build to start
|
|
378
|
+
progress_status.update("starting model image build...")
|
|
379
|
+
progress_status.increment()
|
|
380
|
+
|
|
381
|
+
# Poll for model build to start if not using existing service
|
|
382
|
+
if not model_inference_service_exists:
|
|
383
|
+
self._wait_for_service_status(
|
|
384
|
+
model_build_service_name,
|
|
385
|
+
service_sql.ServiceStatus.RUNNING,
|
|
386
|
+
service_database_name,
|
|
387
|
+
service_schema_name,
|
|
388
|
+
async_job,
|
|
389
|
+
statement_params,
|
|
390
|
+
)
|
|
341
391
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
392
|
+
# Step 5: Building model image
|
|
393
|
+
progress_status.update("building model image...")
|
|
394
|
+
progress_status.increment()
|
|
395
|
+
|
|
396
|
+
# Poll for model build completion
|
|
397
|
+
if not model_inference_service_exists:
|
|
398
|
+
self._wait_for_service_status(
|
|
399
|
+
model_build_service_name,
|
|
400
|
+
service_sql.ServiceStatus.DONE,
|
|
401
|
+
service_database_name,
|
|
402
|
+
service_schema_name,
|
|
403
|
+
async_job,
|
|
404
|
+
statement_params,
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
# Step 6: Deploying model service (push complete, starting inference service)
|
|
408
|
+
progress_status.update("deploying model service...")
|
|
409
|
+
progress_status.increment()
|
|
410
|
+
|
|
411
|
+
log_thread.join()
|
|
412
|
+
|
|
413
|
+
res = cast(str, cast(list[row.Row], async_job.result())[0][0])
|
|
414
|
+
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
|
415
|
+
return res
|
|
416
|
+
|
|
417
|
+
except RuntimeError as e:
|
|
418
|
+
# Handle service creation/deployment failures
|
|
419
|
+
error_msg = f"Model service deployment failed: {str(e)}"
|
|
420
|
+
module_logger.error(error_msg)
|
|
421
|
+
|
|
422
|
+
# Update progress status to show failure
|
|
423
|
+
progress_status.update(error_msg, state="error")
|
|
424
|
+
|
|
425
|
+
# Stop the log thread if it's running
|
|
426
|
+
if "log_thread" in locals() and log_thread.is_alive():
|
|
427
|
+
log_thread.join(timeout=5) # Give it a few seconds to finish gracefully
|
|
428
|
+
|
|
429
|
+
# Re-raise the exception to propagate the error
|
|
430
|
+
raise RuntimeError(error_msg) from e
|
|
431
|
+
|
|
432
|
+
return async_job
|
|
433
|
+
|
|
434
|
+
@staticmethod
|
|
435
|
+
def _get_image_repo_fqn(
|
|
436
|
+
image_repo_name: Optional[str],
|
|
437
|
+
database_name: sql_identifier.SqlIdentifier,
|
|
438
|
+
schema_name: sql_identifier.SqlIdentifier,
|
|
439
|
+
) -> Optional[str]:
|
|
440
|
+
"""Get the fully qualified name of the image repository."""
|
|
441
|
+
if image_repo_name is None or image_repo_name.strip() == "":
|
|
442
|
+
return None
|
|
443
|
+
# Parse image repo
|
|
444
|
+
(
|
|
445
|
+
image_repo_database_name,
|
|
446
|
+
image_repo_schema_name,
|
|
447
|
+
image_repo_name,
|
|
448
|
+
) = sql_identifier.parse_fully_qualified_name(image_repo_name)
|
|
449
|
+
image_repo_database_name = image_repo_database_name or database_name
|
|
450
|
+
image_repo_schema_name = image_repo_schema_name or schema_name
|
|
451
|
+
return identifier.get_schema_level_object_identifier(
|
|
452
|
+
db=image_repo_database_name.identifier(),
|
|
453
|
+
schema=image_repo_schema_name.identifier(),
|
|
454
|
+
object_name=image_repo_name.identifier(),
|
|
455
|
+
)
|
|
347
456
|
|
|
348
457
|
def _start_service_log_streaming(
|
|
349
458
|
self,
|
|
@@ -579,6 +688,7 @@ class ServiceOperator:
|
|
|
579
688
|
is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
|
|
580
689
|
contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
|
|
581
690
|
matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
|
|
691
|
+
|
|
582
692
|
if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
|
|
583
693
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
584
694
|
time.sleep(5)
|
|
@@ -618,6 +728,101 @@ class ServiceOperator:
|
|
|
618
728
|
except Exception as ex:
|
|
619
729
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
620
730
|
|
|
731
|
+
def _wait_for_service_status(
|
|
732
|
+
self,
|
|
733
|
+
service_name: sql_identifier.SqlIdentifier,
|
|
734
|
+
target_status: service_sql.ServiceStatus,
|
|
735
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
736
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
737
|
+
async_job: snowpark.AsyncJob,
|
|
738
|
+
statement_params: Optional[dict[str, Any]] = None,
|
|
739
|
+
timeout_minutes: int = 30,
|
|
740
|
+
) -> None:
|
|
741
|
+
"""Wait for service to reach the specified status while monitoring async job for failures.
|
|
742
|
+
|
|
743
|
+
Args:
|
|
744
|
+
service_name: The service to monitor
|
|
745
|
+
target_status: The target status to wait for
|
|
746
|
+
service_database_name: Database containing the service
|
|
747
|
+
service_schema_name: Schema containing the service
|
|
748
|
+
async_job: The async job to monitor for completion/failure
|
|
749
|
+
statement_params: SQL statement parameters
|
|
750
|
+
timeout_minutes: Maximum time to wait before timing out
|
|
751
|
+
|
|
752
|
+
Raises:
|
|
753
|
+
RuntimeError: If service fails, times out, or enters an error state
|
|
754
|
+
"""
|
|
755
|
+
start_time = time.time()
|
|
756
|
+
timeout_seconds = timeout_minutes * 60
|
|
757
|
+
service_seen_before = False
|
|
758
|
+
|
|
759
|
+
while True:
|
|
760
|
+
# Check if async job has failed (but don't return on success - we need specific service status)
|
|
761
|
+
if async_job.is_done():
|
|
762
|
+
try:
|
|
763
|
+
async_job.result()
|
|
764
|
+
# Async job completed successfully, but we're waiting for a specific service status
|
|
765
|
+
# This might mean the service completed and was cleaned up
|
|
766
|
+
module_logger.debug(
|
|
767
|
+
f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
|
|
768
|
+
)
|
|
769
|
+
except Exception as e:
|
|
770
|
+
raise RuntimeError(f"Service deployment failed: {e}")
|
|
771
|
+
|
|
772
|
+
try:
|
|
773
|
+
statuses = self._service_client.get_service_container_statuses(
|
|
774
|
+
database_name=service_database_name,
|
|
775
|
+
schema_name=service_schema_name,
|
|
776
|
+
service_name=service_name,
|
|
777
|
+
include_message=True,
|
|
778
|
+
statement_params=statement_params,
|
|
779
|
+
)
|
|
780
|
+
|
|
781
|
+
if statuses:
|
|
782
|
+
service_seen_before = True
|
|
783
|
+
current_status = statuses[0].service_status
|
|
784
|
+
|
|
785
|
+
# Check if we've reached the target status
|
|
786
|
+
if current_status == target_status:
|
|
787
|
+
return
|
|
788
|
+
|
|
789
|
+
# Check for failure states
|
|
790
|
+
if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
|
|
791
|
+
error_msg = f"Service {service_name} failed with status {current_status.value}"
|
|
792
|
+
if statuses[0].message:
|
|
793
|
+
error_msg += f": {statuses[0].message}"
|
|
794
|
+
raise RuntimeError(error_msg)
|
|
795
|
+
|
|
796
|
+
except exceptions.SnowparkSQLException as e:
|
|
797
|
+
# Service might not exist yet - this is expected during initial deployment
|
|
798
|
+
if "does not exist" in str(e) or "002003" in str(e):
|
|
799
|
+
# If we're waiting for DONE status and we've seen the service before,
|
|
800
|
+
# it likely completed and was cleaned up
|
|
801
|
+
if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
|
|
802
|
+
module_logger.debug(
|
|
803
|
+
f"Service {service_name} disappeared after being seen, "
|
|
804
|
+
f"assuming it reached {target_status.value} and was cleaned up"
|
|
805
|
+
)
|
|
806
|
+
return
|
|
807
|
+
|
|
808
|
+
module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
|
|
809
|
+
else:
|
|
810
|
+
# Re-raise unexpected SQL exceptions
|
|
811
|
+
raise RuntimeError(f"Error checking service status: {e}")
|
|
812
|
+
except Exception as e:
|
|
813
|
+
# Re-raise unexpected exceptions instead of masking them
|
|
814
|
+
raise RuntimeError(f"Unexpected error while waiting for service status: {e}")
|
|
815
|
+
|
|
816
|
+
# Check timeout
|
|
817
|
+
elapsed_time = time.time() - start_time
|
|
818
|
+
if elapsed_time > timeout_seconds:
|
|
819
|
+
raise RuntimeError(
|
|
820
|
+
f"Timeout waiting for service {service_name} to reach status {target_status.value} "
|
|
821
|
+
f"after {timeout_minutes} minutes"
|
|
822
|
+
)
|
|
823
|
+
|
|
824
|
+
time.sleep(2) # Poll every 2 seconds
|
|
825
|
+
|
|
621
826
|
@staticmethod
|
|
622
827
|
def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
|
|
623
828
|
"""Get the service ID through the server-side logic."""
|
|
@@ -675,9 +880,7 @@ class ServiceOperator:
|
|
|
675
880
|
job_name: sql_identifier.SqlIdentifier,
|
|
676
881
|
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
677
882
|
warehouse_name: sql_identifier.SqlIdentifier,
|
|
678
|
-
|
|
679
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
680
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
883
|
+
image_repo_name: Optional[str],
|
|
681
884
|
output_table_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
682
885
|
output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
683
886
|
output_table_name: sql_identifier.SqlIdentifier,
|
|
@@ -698,8 +901,7 @@ class ServiceOperator:
|
|
|
698
901
|
job_database_name = job_database_name or database_name or self._database_name
|
|
699
902
|
job_schema_name = job_schema_name or schema_name or self._schema_name
|
|
700
903
|
|
|
701
|
-
|
|
702
|
-
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
904
|
+
image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
|
|
703
905
|
|
|
704
906
|
input_table_database_name = job_database_name
|
|
705
907
|
input_table_schema_name = job_schema_name
|
|
@@ -783,9 +985,7 @@ class ServiceOperator:
|
|
|
783
985
|
|
|
784
986
|
self._model_deployment_spec.add_image_build_spec(
|
|
785
987
|
image_build_compute_pool_name=compute_pool_name,
|
|
786
|
-
|
|
787
|
-
image_repo_schema_name=image_repo_schema_name,
|
|
788
|
-
image_repo_name=image_repo_name,
|
|
988
|
+
fully_qualified_image_repo_name=image_repo_fqn,
|
|
789
989
|
force_rebuild=force_rebuild,
|
|
790
990
|
external_access_integrations=build_external_access_integrations,
|
|
791
991
|
)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import pathlib
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any, Optional, Union
|
|
4
5
|
|
|
5
6
|
import yaml
|
|
6
7
|
|
|
7
8
|
from snowflake.ml._internal.utils import identifier, sql_identifier
|
|
9
|
+
from snowflake.ml.model import inference_engine as inference_engine_module
|
|
8
10
|
from snowflake.ml.model._client.service import model_deployment_spec_schema
|
|
9
11
|
|
|
10
12
|
|
|
@@ -24,6 +26,8 @@ class ModelDeploymentSpec:
|
|
|
24
26
|
self._service: Optional[model_deployment_spec_schema.Service] = None
|
|
25
27
|
self._job: Optional[model_deployment_spec_schema.Job] = None
|
|
26
28
|
self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
|
|
29
|
+
# this is referring to custom inference engine spec (vllm, sglang, etc)
|
|
30
|
+
self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None
|
|
27
31
|
self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job
|
|
28
32
|
|
|
29
33
|
self.database: Optional[sql_identifier.SqlIdentifier] = None
|
|
@@ -71,10 +75,8 @@ class ModelDeploymentSpec:
|
|
|
71
75
|
|
|
72
76
|
def add_image_build_spec(
|
|
73
77
|
self,
|
|
74
|
-
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
75
|
-
|
|
76
|
-
image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
77
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
78
|
+
image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None,
|
|
79
|
+
fully_qualified_image_repo_name: Optional[str] = None,
|
|
78
80
|
force_rebuild: bool = False,
|
|
79
81
|
external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
|
|
80
82
|
) -> "ModelDeploymentSpec":
|
|
@@ -82,33 +84,29 @@ class ModelDeploymentSpec:
|
|
|
82
84
|
|
|
83
85
|
Args:
|
|
84
86
|
image_build_compute_pool_name: Compute pool for image building.
|
|
85
|
-
|
|
86
|
-
image_repo_database_name: Database name for the image repository.
|
|
87
|
-
image_repo_schema_name: Schema name for the image repository.
|
|
87
|
+
fully_qualified_image_repo_name: Fully qualified name of the image repository.
|
|
88
88
|
force_rebuild: Whether to force rebuilding the image.
|
|
89
89
|
external_access_integrations: List of external access integrations.
|
|
90
90
|
|
|
91
91
|
Returns:
|
|
92
92
|
Self for chaining.
|
|
93
93
|
"""
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
),
|
|
111
|
-
)
|
|
94
|
+
if (
|
|
95
|
+
image_build_compute_pool_name is not None
|
|
96
|
+
or fully_qualified_image_repo_name is not None
|
|
97
|
+
or force_rebuild is True
|
|
98
|
+
or external_access_integrations is not None
|
|
99
|
+
):
|
|
100
|
+
self._image_build = model_deployment_spec_schema.ImageBuild(
|
|
101
|
+
compute_pool=(
|
|
102
|
+
None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier()
|
|
103
|
+
),
|
|
104
|
+
image_repo=fully_qualified_image_repo_name,
|
|
105
|
+
force_rebuild=force_rebuild,
|
|
106
|
+
external_access_integrations=(
|
|
107
|
+
[eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
|
|
108
|
+
),
|
|
109
|
+
)
|
|
112
110
|
return self
|
|
113
111
|
|
|
114
112
|
def _add_inference_spec(
|
|
@@ -363,6 +361,86 @@ class ModelDeploymentSpec:
|
|
|
363
361
|
self._model_loggings.append(model_logging)
|
|
364
362
|
return self
|
|
365
363
|
|
|
364
|
+
def add_inference_engine_spec(
|
|
365
|
+
self,
|
|
366
|
+
inference_engine: inference_engine_module.InferenceEngine,
|
|
367
|
+
inference_engine_args: Optional[list[str]] = None,
|
|
368
|
+
) -> "ModelDeploymentSpec":
|
|
369
|
+
"""Add inference engine specification. This must be called after self.add_service_spec().
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
inference_engine: Inference engine.
|
|
373
|
+
inference_engine_args: Inference engine arguments.
|
|
374
|
+
|
|
375
|
+
Returns:
|
|
376
|
+
Self for chaining.
|
|
377
|
+
|
|
378
|
+
Raises:
|
|
379
|
+
ValueError: If inference engine specification is called before add_service_spec().
|
|
380
|
+
ValueError: If the argument does not have a '--' prefix.
|
|
381
|
+
"""
|
|
382
|
+
# TODO: needs to eventually support job deployment spec
|
|
383
|
+
if self._service is None:
|
|
384
|
+
raise ValueError("Inference engine specification must be called after add_service_spec().")
|
|
385
|
+
|
|
386
|
+
if inference_engine_args is None:
|
|
387
|
+
inference_engine_args = []
|
|
388
|
+
|
|
389
|
+
# Validate inference engine
|
|
390
|
+
if inference_engine == inference_engine_module.InferenceEngine.VLLM:
|
|
391
|
+
# Block list for VLLM args that should not be user-configurable
|
|
392
|
+
# make this a set for faster lookup
|
|
393
|
+
block_list = {
|
|
394
|
+
"--host",
|
|
395
|
+
"--port",
|
|
396
|
+
"--allowed-headers",
|
|
397
|
+
"--api-key",
|
|
398
|
+
"--lora-modules",
|
|
399
|
+
"--prompt-adapter",
|
|
400
|
+
"--ssl-keyfile",
|
|
401
|
+
"--ssl-certfile",
|
|
402
|
+
"--ssl-ca-certs",
|
|
403
|
+
"--enable-ssl-refresh",
|
|
404
|
+
"--ssl-cert-reqs",
|
|
405
|
+
"--root-path",
|
|
406
|
+
"--middleware",
|
|
407
|
+
"--disable-frontend-multiprocessing",
|
|
408
|
+
"--enable-request-id-headers",
|
|
409
|
+
"--enable-auto-tool-choice",
|
|
410
|
+
"--tool-call-parser",
|
|
411
|
+
"--tool-parser-plugin",
|
|
412
|
+
"--log-config-file",
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
filtered_args = []
|
|
416
|
+
for arg in inference_engine_args:
|
|
417
|
+
# Check if the argument has a '--' prefix
|
|
418
|
+
if not arg.startswith("--"):
|
|
419
|
+
raise ValueError(
|
|
420
|
+
f"""The argument {arg} is not allowed for configuration in Snowflake ML's
|
|
421
|
+
{inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""",
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
# Filter out blocked args and warn user
|
|
425
|
+
if arg.split("=")[0] in block_list:
|
|
426
|
+
warnings.warn(
|
|
427
|
+
f"""The argument {arg} is not allowed for configuration in Snowflake ML's
|
|
428
|
+
{inference_engine.value} inference engine. It will be ignored.""",
|
|
429
|
+
UserWarning,
|
|
430
|
+
stacklevel=2,
|
|
431
|
+
)
|
|
432
|
+
else:
|
|
433
|
+
filtered_args.append(arg)
|
|
434
|
+
|
|
435
|
+
inference_engine_args = filtered_args
|
|
436
|
+
|
|
437
|
+
self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
|
|
438
|
+
# convert to string to be saved in the deployment spec
|
|
439
|
+
inference_engine_name=inference_engine.value,
|
|
440
|
+
inference_engine_args=inference_engine_args,
|
|
441
|
+
)
|
|
442
|
+
return self
|
|
443
|
+
|
|
366
444
|
def save(self) -> str:
|
|
367
445
|
"""Constructs the final deployment spec from added components and saves it.
|
|
368
446
|
|
|
@@ -377,8 +455,6 @@ class ModelDeploymentSpec:
|
|
|
377
455
|
# Validations
|
|
378
456
|
if not self._models:
|
|
379
457
|
raise ValueError("Model specification is required. Call add_model_spec().")
|
|
380
|
-
if not self._image_build:
|
|
381
|
-
raise ValueError("Image build specification is required. Call add_image_build_spec().")
|
|
382
458
|
if not self._service and not self._job:
|
|
383
459
|
raise ValueError(
|
|
384
460
|
"Either service or job specification is required. Call add_service_spec() or add_job_spec()."
|
|
@@ -10,10 +10,15 @@ class Model(BaseModel):
|
|
|
10
10
|
version: str
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
class InferenceEngineSpec(BaseModel):
|
|
14
|
+
inference_engine_name: str
|
|
15
|
+
inference_engine_args: Optional[list[str]] = None
|
|
16
|
+
|
|
17
|
+
|
|
13
18
|
class ImageBuild(BaseModel):
|
|
14
|
-
compute_pool: str
|
|
15
|
-
image_repo: str
|
|
16
|
-
force_rebuild: bool
|
|
19
|
+
compute_pool: Optional[str] = None
|
|
20
|
+
image_repo: Optional[str] = None
|
|
21
|
+
force_rebuild: Optional[bool] = None
|
|
17
22
|
external_access_integrations: Optional[list[str]] = None
|
|
18
23
|
|
|
19
24
|
|
|
@@ -27,6 +32,7 @@ class Service(BaseModel):
|
|
|
27
32
|
gpu: Optional[str] = None
|
|
28
33
|
num_workers: Optional[int] = None
|
|
29
34
|
max_batch_rows: Optional[int] = None
|
|
35
|
+
inference_engine_spec: Optional[InferenceEngineSpec] = None
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
class Job(BaseModel):
|
|
@@ -68,13 +74,13 @@ class ModelLogging(BaseModel):
|
|
|
68
74
|
|
|
69
75
|
class ModelServiceDeploymentSpec(BaseModel):
|
|
70
76
|
models: list[Model]
|
|
71
|
-
image_build: ImageBuild
|
|
77
|
+
image_build: Optional[ImageBuild] = None
|
|
72
78
|
service: Service
|
|
73
79
|
model_loggings: Optional[list[ModelLogging]] = None
|
|
74
80
|
|
|
75
81
|
|
|
76
82
|
class ModelJobDeploymentSpec(BaseModel):
|
|
77
83
|
models: list[Model]
|
|
78
|
-
image_build: ImageBuild
|
|
84
|
+
image_build: Optional[ImageBuild] = None
|
|
79
85
|
job: Job
|
|
80
86
|
model_loggings: Optional[list[ModelLogging]] = None
|