snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. snowflake/ml/_internal/utils/url.py +42 -0
  2. snowflake/ml/jobs/__init__.py +2 -0
  3. snowflake/ml/jobs/_utils/constants.py +2 -0
  4. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  5. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  6. snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
  7. snowflake/ml/jobs/_utils/spec_utils.py +0 -31
  8. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  9. snowflake/ml/jobs/_utils/types.py +22 -2
  10. snowflake/ml/jobs/job_definition.py +232 -0
  11. snowflake/ml/jobs/manager.py +16 -177
  12. snowflake/ml/lineage/lineage_node.py +1 -1
  13. snowflake/ml/model/__init__.py +6 -0
  14. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  15. snowflake/ml/model/_client/model/model_version_impl.py +109 -32
  16. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +45 -2
  18. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  19. snowflake/ml/model/_client/ops/service_ops.py +81 -61
  20. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
  23. snowflake/ml/model/_client/sql/model_version.py +30 -6
  24. snowflake/ml/model/_client/sql/service.py +30 -29
  25. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
  31. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
  33. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  36. snowflake/ml/model/_packager/model_packager.py +1 -1
  37. snowflake/ml/model/_signatures/core.py +85 -0
  38. snowflake/ml/model/_signatures/utils.py +55 -0
  39. snowflake/ml/model/code_path.py +104 -0
  40. snowflake/ml/model/custom_model.py +55 -13
  41. snowflake/ml/model/model_signature.py +13 -1
  42. snowflake/ml/model/openai_signatures.py +97 -0
  43. snowflake/ml/model/type_hints.py +2 -0
  44. snowflake/ml/registry/_manager/model_manager.py +230 -15
  45. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  46. snowflake/ml/registry/registry.py +4 -4
  47. snowflake/ml/version.py +1 -1
  48. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
  49. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
  50. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
  51. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  52. {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py

@@ -1,11 +1,5 @@
- import json
  import logging
- import os
- import pathlib
- import sys
- from pathlib import PurePath
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
- from uuid import uuid4

  import pandas as pd

@@ -13,13 +7,8 @@ from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml.jobs import job as jb
- from snowflake.ml.jobs._utils import (
-     constants,
-     feature_flags,
-     payload_utils,
-     query_helper,
-     types,
- )
+ from snowflake.ml.jobs._utils import query_helper
+ from snowflake.ml.jobs.job_definition import MLJobDefinition
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -457,7 +446,6 @@ def _submit_job(
          An object representing the submitted job.

      Raises:
-         ValueError: If database or schema value(s) are invalid
          RuntimeError: If schema is not specified in session context or job submission
      """
      session = _ensure_session(session)
@@ -469,94 +457,30 @@ def _submit_job(
      )
      target_instances = max(target_instances, kwargs.pop("num_instances"))

-     imports = None
      if "additional_payloads" in kwargs:
          logger.warning(
              "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
          )
-         imports = kwargs.pop("additional_payloads")
+         if "imports" not in kwargs:
+             imports = kwargs.pop("additional_payloads", None)
+             kwargs.update({"imports": imports})

      if "runtime_environment" in kwargs:
          logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")

-     # Use kwargs for less common optional parameters
-     database = kwargs.pop("database", None)
-     schema = kwargs.pop("schema", None)
-     min_instances = kwargs.pop("min_instances", target_instances)
-     pip_requirements = kwargs.pop("pip_requirements", None)
-     external_access_integrations = kwargs.pop("external_access_integrations", None)
-     env_vars = kwargs.pop("env_vars", None)
-     spec_overrides = kwargs.pop("spec_overrides", None)
-     enable_metrics = kwargs.pop("enable_metrics", True)
-     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
-     imports = kwargs.pop("imports", None) or imports
-     # if the mljob is submitted from a notebook, we use the same image tag as the notebook
-     runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
-
-     # Warn if there are unknown kwargs
-     if kwargs:
-         logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
-
-     # Validate parameters
-     if database and not schema:
-         raise ValueError("Schema must be specified if database is specified.")
-     if target_instances < 1:
-         raise ValueError("target_instances must be greater than 0.")
-     if not (0 < min_instances <= target_instances):
-         raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-     if min_instances > 1:
-         # Validate min_instances against compute pool max_nodes
-         pool_info = jb._get_compute_pool_info(session, compute_pool)
-         max_nodes = int(pool_info["max_nodes"])
-         if min_instances > max_nodes:
-             raise ValueError(
-                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
-                 f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
-             )
-
-     job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-     job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
-     stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
-     stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
-     stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
-
-     try:
-         # Upload payload
-         uploaded_payload = payload_utils.JobPayload(
-             source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
-         ).upload(session, stage_path)
-     except SnowparkSQLException as e:
-         if e.sql_error_code == 90106:
-             raise RuntimeError(
-                 "Please specify a schema, either in the session context or as a parameter in the job submission"
-             )
-         elif e.sql_error_code == 3001 and "schema" in str(e).lower():
-             raise RuntimeError(
-                 "please grant privileges on schema before submitting a job, see",
-                 "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements",
-                 " for more details",
-             ) from e
-         raise
-
-     combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+     job_definition = MLJobDefinition.register(
+         source,
+         compute_pool,
+         stage_name,
+         session or get_active_session(),
+         entrypoint,
+         target_instances,
+         generate_suffix=True,
+         **kwargs,
+     )

      try:
-         return _do_submit_job(
-             session=session,
-             payload=uploaded_payload,
-             args=args,
-             env_vars=combined_env_vars,
-             spec_overrides=spec_overrides,
-             compute_pool=compute_pool,
-             job_id=job_id,
-             external_access_integrations=external_access_integrations,
-             query_warehouse=query_warehouse,
-             target_instances=target_instances,
-             min_instances=min_instances,
-             enable_metrics=enable_metrics,
-             use_async=True,
-             runtime_environment=runtime_environment,
-         )
+         return job_definition(*(args or []))
      except SnowparkSQLException as e:
          if e.sql_error_code == 3001 and "schema" in str(e).lower():
              raise RuntimeError(
@@ -567,91 +491,6 @@ def _submit_job(
          raise


- def _do_submit_job(
-     session: snowpark.Session,
-     payload: types.UploadedPayload,
-     args: Optional[list[str]],
-     env_vars: dict[str, str],
-     spec_overrides: dict[str, Any],
-     compute_pool: str,
-     job_id: Optional[str] = None,
-     external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     target_instances: int = 1,
-     min_instances: int = 1,
-     enable_metrics: bool = True,
-     use_async: bool = True,
-     runtime_environment: Optional[str] = None,
- ) -> jb.MLJob[Any]:
-     """
-     Generate the SQL query for job submission.
-
-     Args:
-         session: The Snowpark session to use.
-         payload: The uploaded job payload.
-         args: Arguments to pass to the entrypoint script.
-         env_vars: Environment variables to set in the job container.
-         spec_overrides: Custom service specification overrides.
-         compute_pool: The compute pool to use for job execution.
-         job_id: The ID of the job.
-         external_access_integrations: Optional list of external access integrations.
-         query_warehouse: Optional query warehouse to use.
-         target_instances: Number of instances for multi-node job.
-         min_instances: Minimum number of instances required to start the job.
-         enable_metrics: Whether to enable platform metrics for the job.
-         use_async: Whether to run the job asynchronously.
-         runtime_environment: image tag or full image URL to use for the job.
-
-     Returns:
-         The job object.
-     """
-     args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
-     spec_options = {
-         "STAGE_PATH": payload.stage_path.as_posix(),
-         "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
-         "ARGS": args,
-         "ENV_VARS": env_vars,
-         "ENABLE_METRICS": enable_metrics,
-         "SPEC_OVERRIDES": spec_overrides,
-     }
-     if runtime_environment:
-         # for the image tag or full image URL, we use that directly
-         spec_options["RUNTIME"] = runtime_environment
-     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
-         # when feature flag is enabled, we get the local python version and wrap it in a dict
-         # in system function, we can know whether it is python version or image tag or full image URL through the format
-         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
-
-     job_options = {
-         "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
-         "QUERY_WAREHOUSE": query_warehouse,
-         "TARGET_INSTANCES": target_instances,
-         "MIN_INSTANCES": min_instances,
-         "ASYNC": use_async,
-     }
-
-     if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
-         spec_options["ENABLE_STAGE_MOUNT_V2"] = True
-     if payload.payload_name:
-         job_options["GENERATE_SUFFIX"] = True
-     job_options = {k: v for k, v in job_options.items() if v is not None}
-
-     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-     if job_id:
-         database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
-     params = [
-         job_id
-         if payload.payload_name is None
-         else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
-         compute_pool,
-         json.dumps(spec_options),
-         json.dumps(job_options),
-     ]
-     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
-
-     return get_job(actual_job_id, session=session)
-
-
  def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
      try:
          session = session or get_active_session()
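Note on the refactor above: payload upload, parameter validation, and the SYSTEM$EXECUTE_ML_JOB call that previously lived in _submit_job/_do_submit_job now sit behind MLJobDefinition. A minimal sketch of the new flow, using only the call shape visible in this hunk (the full MLJobDefinition.register signature lives in the new snowflake/ml/jobs/job_definition.py; all argument values below are placeholders):

    # Sketch of the refactored submission path; values are placeholders.
    from snowflake.ml.jobs.job_definition import MLJobDefinition
    from snowflake.snowpark.context import get_active_session

    session = get_active_session()

    # Register a reusable job definition (same argument order as in the hunk).
    job_definition = MLJobDefinition.register(
        "./src",                 # source: payload directory or file (placeholder)
        "MY_COMPUTE_POOL",       # compute pool (placeholder)
        "@my_stage/payloads",    # stage for the uploaded payload (placeholder)
        session,
        "./src/train.py",        # entrypoint (placeholder)
        1,                       # target_instances
        generate_suffix=True,    # append a unique suffix to the job name
    )

    # Invoking the definition submits one run; positional args reach the entrypoint.
    job = job_definition("--epochs", "5")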
snowflake/ml/lineage/lineage_node.py

@@ -7,7 +7,7 @@ from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier, mixins

  if TYPE_CHECKING:
-     from snowflake.ml import dataset
+     from snowflake.ml.dataset import dataset
      from snowflake.ml.feature_store import feature_view
      from snowflake.ml.model._client.model import model_version_impl

snowflake/ml/model/__init__.py

@@ -2,16 +2,20 @@ import sys
  import warnings

  from snowflake.ml.model._client.model.batch_inference_specs import (
+     ColumnHandlingOptions,
+     FileEncoding,
      JobSpec,
      OutputSpec,
      SaveMode,
  )
  from snowflake.ml.model._client.model.model_impl import Model
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
+ from snowflake.ml.model.code_path import CodePath
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
  from snowflake.ml.model.volatility import Volatility

  __all__ = [
+     "CodePath",
      "Model",
      "ModelVersion",
      "ExportMode",
@@ -20,6 +24,8 @@ __all__ = [
      "OutputSpec",
      "SaveMode",
      "Volatility",
+     "FileEncoding",
+     "ColumnHandlingOptions",
  ]

  _deprecation_warning_msg_for_3_9 = (
snowflake/ml/model/_client/model/batch_inference_specs.py

@@ -2,6 +2,7 @@ from enum import Enum
  from typing import Optional

  from pydantic import BaseModel
+ from typing_extensions import TypedDict


  class SaveMode(str, Enum):
@@ -18,6 +19,20 @@ class SaveMode(str, Enum):
      ERROR = "error"


+ class FileEncoding(str, Enum):
+     """The encoding of the file content that will be passed to the custom model."""
+
+     RAW_BYTES = "raw_bytes"
+     BASE64 = "base64"
+     BASE64_DATA_URL = "base64_data_url"
+
+
+ class ColumnHandlingOptions(TypedDict):
+     """Options for handling specific columns during run_batch for file I/O."""
+
+     encoding: FileEncoding
+
+
  class OutputSpec(BaseModel):
      """Specification for batch inference output.

@@ -74,7 +89,7 @@ class JobSpec(BaseModel):
              the memory of the node.
          gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
              string values. Use CPU if None.
-         replicas (Optional[int]): Number of job replicas to run for high availability.
+         replicas (Optional[int]): Number of SPCS job nodes used for distributed inference.
              If not specified, defaults to 1 replica.

      Example:
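Usage note on the two new types above: together with run_batch's new column_handling parameter (shown later in this diff), they let a caller say how a file column should be decoded before it reaches a custom model. A minimal sketch, assuming a hypothetical column name; the types and enum values are exactly those added here:

    # Sketch: map a staged-file column to an encoding policy for run_batch.
    from snowflake.ml.model import ColumnHandlingOptions, FileEncoding

    # "IMAGE_PATH" is a hypothetical column name.
    column_handling: dict[str, ColumnHandlingOptions] = {
        "IMAGE_PATH": {"encoding": FileEncoding.BASE64_DATA_URL},
    }

Since ColumnHandlingOptions is a TypedDict, a plain dict literal like this type-checks without instantiating a class.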
snowflake/ml/model/_client/model/model_version_impl.py

@@ -30,6 +30,10 @@ _TELEMETRY_PROJECT = "MLOps"
  _TELEMETRY_SUBPROJECT = "ModelManagement"
  _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
  _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
+ VLLM_SUPPORTED_TASKS = [
+     "text-generation",
+     "image-text-to-text",
+ ]


  class ExportMode(enum.Enum):
@@ -495,6 +499,7 @@ class ModelVersion(lineage_node.LineageNode):
          function_name: Optional[str] = None,
          partition_column: Optional[str] = None,
          strict_input_validation: bool = False,
+         params: Optional[dict[str, Any]] = None,
      ) -> Union[pd.DataFrame, dataframe.DataFrame]:
          """Invoke a method in a model version object.

@@ -505,6 +510,8 @@ class ModelVersion(lineage_node.LineageNode):
              partition_column: The partition column name to partition by.
              strict_input_validation: Enable stricter validation for the input data. This will result value range based
                  type validation to make sure your input data won't overflow when providing to the model.
+             params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                 These are passed as keyword arguments to the model's inference method. Defaults to None.
          """
          ...

@@ -516,6 +523,7 @@ class ModelVersion(lineage_node.LineageNode):
          service_name: str,
          function_name: Optional[str] = None,
          strict_input_validation: bool = False,
+         params: Optional[dict[str, Any]] = None,
      ) -> Union[pd.DataFrame, dataframe.DataFrame]:
          """Invoke a method in a model version object via a service.

@@ -525,6 +533,8 @@ class ModelVersion(lineage_node.LineageNode):
              function_name: The function name to run. It is the name used to call a function in SQL.
              strict_input_validation: Enable stricter validation for the input data. This will result value range based
                  type validation to make sure your input data won't overflow when providing to the model.
+             params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                 These are passed as keyword arguments to the model's inference method. Defaults to None.
          """
          ...

@@ -541,6 +551,7 @@ class ModelVersion(lineage_node.LineageNode):
          function_name: Optional[str] = None,
          partition_column: Optional[str] = None,
          strict_input_validation: bool = False,
+         params: Optional[dict[str, Any]] = None,
      ) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
          """Invoke a method in a model version object via the warehouse or a service.

@@ -552,6 +563,8 @@ class ModelVersion(lineage_node.LineageNode):
              partition_column: The partition column name to partition by.
              strict_input_validation: Enable stricter validation for the input data. This will result value range based
                  type validation to make sure your input data won't overflow when providing to the model.
+             params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                 These are passed as keyword arguments to the model's inference method. Defaults to None.

          Returns:
              The prediction data. It would be the same type dataframe as your input.
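Usage note: because params is forwarded verbatim as keyword arguments to the model's inference method, call sites stay simple. A minimal sketch, assuming mv is a ModelVersion fetched from the registry and that the underlying model accepts these particular parameters:

    # Sketch: passing inference-time parameters through ModelVersion.run.
    result_df = mv.run(
        input_df,                 # pandas or Snowpark DataFrame (assumed)
        function_name="predict",  # placeholder method name
        params={"temperature": 0.7, "top_k": 50},  # forwarded as keyword args
    )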
@@ -582,6 +595,7 @@ class ModelVersion(lineage_node.LineageNode):
                  service_name=service_name_id,
                  strict_input_validation=strict_input_validation,
                  statement_params=statement_params,
+                 params=params,
              )
          else:
              manifest = self._get_model_manifest(statement_params=statement_params)
@@ -621,6 +635,7 @@ class ModelVersion(lineage_node.LineageNode):
              statement_params=statement_params,
              is_partitioned=target_function_info["is_partitioned"],
              explain_case_sensitive=explain_case_sensitive,
+             params=params,
          )

      def _determine_explain_case_sensitivity(
@@ -651,6 +666,9 @@ class ModelVersion(lineage_node.LineageNode):
          input_spec: dataframe.DataFrame,
          output_spec: batch_inference_specs.OutputSpec,
          job_spec: Optional[batch_inference_specs.JobSpec] = None,
+         params: Optional[dict[str, Any]] = None,
+         column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
+         inference_engine_options: Optional[dict[str, Any]] = None,
      ) -> job.MLJob[Any]:
          """Execute batch inference on datasets as an SPCS job.

@@ -664,6 +682,16 @@ class ModelVersion(lineage_node.LineageNode):
              job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                  execution parameters such as compute resources, worker counts, and job naming.
                  If None, default values will be used.
+             params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+                 (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+                 model's inference method. Defaults to None.
+             column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
+                 specifying how to handle specific columns during file I/O. Maps column names to their
+                 file encoding configuration.
+             inference_engine_options: Options for the service creation with custom inference engine.
+                 Supports `engine` and `engine_args_override`.
+                 `engine` is the type of the inference engine to use.
+                 `engine_args_override` is a list of string arguments to pass to the inference engine.

          Returns:
              job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
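Usage note: the three new run_batch parameters compose. A hedged sketch, assuming mv, input_df, and output_spec already exist; the "vllm" engine string is an assumption suggested by VLLM_SUPPORTED_TASKS (the accepted identifiers are not shown in this diff), and the override flags are illustrative only:

    # Sketch: run_batch with the options added in this diff.
    from snowflake.ml.model import FileEncoding

    batch_job = mv.run_batch(
        input_spec=input_df,          # Snowpark DataFrame over staged files (assumed)
        output_spec=output_spec,      # an OutputSpec; its fields are not shown in this diff
        params={"temperature": 0.2},  # forwarded to the model's inference method
        column_handling={"IMAGE_PATH": {"encoding": FileEncoding.BASE64_DATA_URL}},
        inference_engine_options={
            "engine": "vllm",                                     # assumed identifier
            "engine_args_override": ["--max-model-len", "4096"],  # illustrative flags
        },
    )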
@@ -722,6 +750,15 @@ class ModelVersion(lineage_node.LineageNode):
          if job_spec is None:
              job_spec = batch_inference_specs.JobSpec()

+         # Validate GPU support if GPU resources are requested
+         self._throw_error_if_gpu_is_not_supported(job_spec.gpu_requests, statement_params)
+
+         inference_engine_args = self._prepare_inference_engine_args(
+             inference_engine_options,
+             job_spec.gpu_requests,
+             statement_params,
+         )
+
          warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
          if warehouse is None:
              raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
@@ -746,12 +783,14 @@ class ModelVersion(lineage_node.LineageNode):
          else:
              job_name = job_spec.job_name

+         target_function_info = self._get_function_info(function_name=job_spec.function_name)
+
          return self._service_ops.invoke_batch_job_method(
              # model version info
              model_name=self._model_name,
              version_name=self._version_name,
              # job spec
-             function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
+             function_name=target_function_info["target_method"],
              compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
              force_rebuild=job_spec.force_rebuild,
              image_repo_name=job_spec.image_repo,
@@ -766,10 +805,14 @@ class ModelVersion(lineage_node.LineageNode):
              # input and output
              input_stage_location=input_stage_location,
              input_file_pattern="*",
+             column_handling=column_handling,
+             params=params,
+             signature_params=target_function_info["signature"].params,
              output_stage_location=output_stage_location,
              completion_filename="_SUCCESS",
              # misc
              statement_params=statement_params,
+             inference_engine_args=inference_engine_args,
          )

      def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
@@ -985,20 +1028,55 @@ class ModelVersion(lineage_node.LineageNode):
              " the `log_model` function."
          )

-     def _check_huggingface_text_generation_model(
+     def _prepare_inference_engine_args(
+         self,
+         inference_engine_options: Optional[dict[str, Any]],
+         gpu_requests: Optional[Union[str, int]],
+         statement_params: Optional[dict[str, Any]] = None,
+     ) -> Optional[service_ops.InferenceEngineArgs]:
+         """Prepare and validate inference engine arguments.
+
+         This method handles the common logic for processing inference engine options:
+         1. Parse inference engine options into InferenceEngineArgs
+         2. Validate that the model is a HuggingFace text-generation model (if inference engine is specified)
+         3. Enrich inference engine args
+
+         Args:
+             inference_engine_options: Optional dictionary containing inference engine configuration.
+             gpu_requests: GPU resource request string (e.g., "4").
+             statement_params: Optional dictionary of statement parameters for SQL commands.
+
+         Returns:
+             Prepared InferenceEngineArgs or None if no inference engine is specified.
+         """
+         inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
+
+         if inference_engine_args is not None:
+             # Validate that model is HuggingFace vLLM supported model and is logged with
+             # OpenAI compatible signature.
+             self._check_huggingface_vllm_supported_model(statement_params)
+             # Enrich with GPU configuration
+             inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                 inference_engine_args,
+                 gpu_requests,
+             )
+
+         return inference_engine_args
+
+     def _check_huggingface_vllm_supported_model(
          self,
          statement_params: Optional[dict[str, Any]] = None,
      ) -> None:
-         """Check if the model is a HuggingFace pipeline with text-generation task
-         and is logged with OPENAI_CHAT_SIGNATURE.
+         """Check if the model is a HuggingFace pipeline with vLLM supported task
+         and is logged with OpenAI compatible signature.

          Args:
              statement_params: Optional dictionary of statement parameters to include
                  in the SQL command to fetch model spec.

          Raises:
-             ValueError: If the model is not a HuggingFace text-generation model or
-                 if the model is not logged with OPENAI_CHAT_SIGNATURE.
+             ValueError: If the model is not a HuggingFace vLLM supported model or
+                 if the model is not logged with OpenAI compatible signature.
          """
          # Fetch model spec
          model_spec = self._get_model_spec(statement_params)
@@ -1007,34 +1085,37 @@ class ModelVersion(lineage_node.LineageNode):
          model_type = model_spec.get("model_type")
          if model_type != "huggingface_pipeline":
              raise ValueError(
-                 f"Inference engine is only supported for HuggingFace text-generation models. "
+                 f"Inference engine is only supported for HuggingFace vLLM supported models. "
                  f"Found model_type: {model_type}"
              )

-         # Check if model supports text-generation task
+         # Check if model supports vLLM supported task
          # There should only be one model in the list because we don't support multiple models in a single model spec
          models = model_spec.get("models", {})
-         is_text_generation = False
+         is_vllm_supported_task = False
          found_tasks: list[str] = []

-         # As long as the model supports text-generation task, we can use it
+         # As long as the model supports vLLM supported task, we can use it
          for _, model_info in models.items():
              options = model_info.get("options", {})
              task = options.get("task")
              if task:
                  found_tasks.append(str(task))
-                 if task == "text-generation":
-                     is_text_generation = True
+                 if task in VLLM_SUPPORTED_TASKS:
+                     is_vllm_supported_task = True
                      break

-         if not is_text_generation:
+         if not is_vllm_supported_task:
              tasks_str = ", ".join(found_tasks)
              found_tasks_str = (
                  f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
              )
-             raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
+             supported_tasks_str = ", ".join(VLLM_SUPPORTED_TASKS)
+             raise ValueError(
+                 f"Inference engine is only supported for vLLM supported tasks. {supported_tasks_str}. {found_tasks_str}"
+             )

-         # Check if the model is logged with OPENAI_CHAT_SIGNATURE
+         # Check if the model is logged with OpenAI compatible signature.
          signatures_dict = model_spec.get("signatures", {})

          # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
@@ -1042,11 +1123,16 @@ class ModelVersion(lineage_node.LineageNode):
              func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
          }

-         if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+         if deserialized_signatures not in [
+             openai_signatures.OPENAI_CHAT_SIGNATURE,
+             openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
+         ]:
              raise ValueError(
-                 "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+                 "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                 "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
                  f"Found signatures: {signatures_dict}. "
-                 "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+                 "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                 "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
              )

      @overload
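Usage note: to satisfy the check above ahead of time, log the HuggingFace pipeline with one of the two accepted signatures. A minimal sketch; the registry handle, model object, and names are placeholders:

    # Sketch: logging a model so it passes _check_huggingface_vllm_supported_model.
    from snowflake.ml.model import openai_signatures
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    mv = reg.log_model(
        text_generation_pipeline,     # a transformers pipeline (placeholder)
        model_name="MY_LLM",
        version_name="V1",
        signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
    )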
@@ -1287,20 +1373,11 @@ class ModelVersion(lineage_node.LineageNode):
          # Validate GPU support if GPU resources are requested
          self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)

-         inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
-
-         # Check if model is HuggingFace text-generation and is logged with
-         # OPENAI_CHAT_SIGNATURE before doing inference engine checks
-         # Only validate if inference engine is actually specified
-         if inference_engine_args is not None:
-             self._check_huggingface_text_generation_model(statement_params)
-
-         # Enrich inference engine args if inference engine is specified
-         if inference_engine_args is not None:
-             inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
-                 inference_engine_args,
-                 gpu_requests,
-             )
+         inference_engine_args = self._prepare_inference_engine_args(
+             inference_engine_options,
+             gpu_requests,
+             statement_params,
+         )

          from snowflake.ml.model import event_handler
          from snowflake.snowpark import exceptions
snowflake/ml/model/_client/ops/deployment_step.py (new file)

@@ -0,0 +1,36 @@
+ import enum
+ import hashlib
+ from typing import Optional
+
+
+ class DeploymentStep(enum.Enum):
+     MODEL_BUILD = ("model-build", "model_build_")
+     MODEL_INFERENCE = ("model-inference", None)
+     MODEL_LOGGING = ("model-logging", "model_logging_")
+
+     def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
+         self._container_name = container_name
+         self._service_name_prefix = service_name_prefix
+
+     @property
+     def container_name(self) -> str:
+         """Get the container name for the deployment step."""
+         return self._container_name
+
+     @property
+     def service_name_prefix(self) -> Optional[str]:
+         """Get the service name prefix for the deployment step."""
+         return self._service_name_prefix
+
+
+ def get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
+     """Get the service ID through the server-side logic."""
+     uuid = query_id.replace("-", "")
+     big_int = int(uuid, 16)
+     md5_hash = hashlib.md5(str(big_int).encode(), usedforsecurity=False).hexdigest()
+     identifier = md5_hash[:8]
+     service_name_prefix = deployment_step.service_name_prefix
+     if service_name_prefix is None:
+         # raise an exception if the service name prefix is None
+         raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
+     return (service_name_prefix + identifier).upper()
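The derivation above mirrors server-side naming, so a worked example helps: strip dashes from the query id, parse the hex digits as an integer, MD5 the decimal string, and keep the first 8 hex characters. A runnable sketch with a made-up query id (dropping usedforsecurity=False, which only matters on FIPS builds):

    # Worked example of get_service_id_from_deployment_step's hashing steps.
    import hashlib

    query_id = "01b2c3d4-0000-1111-2222-333344445555"  # made-up query id
    uuid = query_id.replace("-", "")       # hex digits only
    big_int = int(uuid, 16)                # hex string -> integer
    md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
    identifier = md5_hash[:8]              # first 8 hex chars

    # DeploymentStep.MODEL_BUILD's prefix is "model_build_":
    service_id = ("model_build_" + identifier).upper()
    print(service_id)  # e.g. MODEL_BUILD_1A2B3C4D (actual digits vary)

Note that MODEL_INFERENCE has no service name prefix, so passing it raises ValueError.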